| author | 韩丁康 <[email protected]> | 2023-12-18 15:40:11 +0800 |
|---|---|---|
| committer | 韩丁康 <[email protected]> | 2023-12-18 15:40:11 +0800 |
| commit | 34c9811440252364c9cbd7acb7a5110e906d8aa5 | (patch) |
| tree | 02e7c853aa40be78f536ae9392b4ce093515d3b6 | /plugin |
| parent | e9faceecc3729727f5d6b552a6335b80e8922be6 | (diff) |
Initial commit
Diffstat (limited to 'plugin')
607 files changed, 64376 insertions, 0 deletions
diff --git a/plugin/acl/README.md b/plugin/acl/README.md new file mode 100644 index 0000000..d957d24 --- /dev/null +++ b/plugin/acl/README.md @@ -0,0 +1,110 @@ +# acl + +## Name + +*acl* - enforces access control policies on source ip and prevents unauthorized access to DNS servers. + +## Description + +With `acl` enabled, users are able to block or filter suspicious DNS queries by configuring IP filter rule sets, i.e. allowing authorized queries or blocking unauthorized queries. + + +When evaluating the rule sets, _acl_ uses the source IP of the TCP/UDP headers of the DNS query received by CoreDNS. +This source IP will be different than the IP of the client originating the request in cases where the source IP of the request is changed in transit. For example: +* if the request passes though an intermediate forwarding DNS server or recursive DNS server before reaching CoreDNS +* if the request traverses a Source NAT before reaching CoreDNS + +This plugin can be used multiple times per Server Block. + +## Syntax + +``` +acl [ZONES...] { + ACTION [type QTYPE...] [net SOURCE...] +} +``` + +- **ZONES** zones it should be authoritative for. If empty, the zones from the configuration block are used. +- **ACTION** (*allow*, *block*, *filter*, or *drop*) defines the way to deal with DNS queries matched by this rule. The default action is *allow*, which means a DNS query not matched by any rules will be allowed to recurse. The difference between *block* and *filter* is that block returns status code of *REFUSED* while filter returns an empty set *NOERROR*. *drop* however returns no response to the client. +- **QTYPE** is the query type to match for the requests to be allowed or blocked. Common resource record types are supported. `*` stands for all record types. The default behavior for an omitted `type QTYPE...` is to match all kinds of DNS queries (same as `type *`). +- **SOURCE** is the source IP address to match for the requests to be allowed or blocked. Typical CIDR notation and single IP address are supported. `*` stands for all possible source IP addresses. + +## Examples + +To demonstrate the usage of plugin acl, here we provide some typical examples. + +Block all DNS queries with record type A from 192.168.0.0/16: + +~~~ corefile +. { + acl { + block type A net 192.168.0.0/16 + } +} +~~~ + +Filter all DNS queries with record type A from 192.168.0.0/16: + +~~~ corefile +. { + acl { + filter type A net 192.168.0.0/16 + } +} +~~~ + +Block all DNS queries from 192.168.0.0/16 except for 192.168.1.0/24: + +~~~ corefile +. { + acl { + allow net 192.168.1.0/24 + block net 192.168.0.0/16 + } +} +~~~ + +Allow only DNS queries from 192.168.0.0/24 and 192.168.1.0/24: + +~~~ corefile +. { + acl { + allow net 192.168.0.0/24 192.168.1.0/24 + block + } +} +~~~ + +Block all DNS queries from 192.168.1.0/24 towards a.example.org: + +~~~ corefile +example.org { + acl a.example.org { + block net 192.168.1.0/24 + } +} +~~~ + +Drop all DNS queries from 192.0.2.0/24: + +~~~ corefile +. { + acl { + drop net 192.0.2.0/24 + } +} +~~~ + +## Metrics + +If monitoring is enabled (via the _prometheus_ plugin) then the following metrics are exported: + +- `coredns_acl_blocked_requests_total{server, zone, view}` - counter of DNS requests being blocked. + +- `coredns_acl_filtered_requests_total{server, zone, view}` - counter of DNS requests being filtered. + +- `coredns_acl_allowed_requests_total{server, view}` - counter of DNS requests being allowed. 
+ +- `coredns_acl_dropped_requests_total{server, zone, view}` - counter of DNS requests being dropped. + +The `server` and `zone` labels are explained in the _metrics_ plugin documentation. diff --git a/plugin/acl/acl.go b/plugin/acl/acl.go new file mode 100644 index 0000000..2632326 --- /dev/null +++ b/plugin/acl/acl.go @@ -0,0 +1,151 @@ +package acl + +import ( + "context" + "net" + "strings" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/metrics" + clog "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/request" + + "github.com/infobloxopen/go-trees/iptree" + "github.com/miekg/dns" +) + +// ACL enforces access control policies on DNS queries. +type ACL struct { + Next plugin.Handler + + Rules []rule +} + +// rule defines a list of Zones and some ACL policies which will be +// enforced on them. +type rule struct { + zones []string + policies []policy +} + +// action defines the action against queries. +type action int + +// policy defines the ACL policy for DNS queries. +// A policy performs the specified action (block/allow) on all DNS queries +// matched by source IP or QTYPE. +type policy struct { + action action + qtypes map[uint16]struct{} + filter *iptree.Tree +} + +const ( + // actionNone does nothing on the queries. + actionNone = iota + // actionAllow allows authorized queries to recurse. + actionAllow + // actionBlock blocks unauthorized queries towards protected DNS zones. + actionBlock + // actionFilter returns empty sets for queries towards protected DNS zones. + actionFilter + // actionDrop does not respond for queries towards the protected DNS zones. + actionDrop +) + +var log = clog.NewWithPlugin("acl") + +// ServeDNS implements the plugin.Handler interface. +func (a ACL) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + state := request.Request{W: w, Req: r} + +RulesCheckLoop: + for _, rule := range a.Rules { + // check zone. + zone := plugin.Zones(rule.zones).Matches(state.Name()) + if zone == "" { + continue + } + + action := matchWithPolicies(rule.policies, w, r) + switch action { + case actionDrop: + { + RequestDropCount.WithLabelValues(metrics.WithServer(ctx), zone, metrics.WithView(ctx)).Inc() + return dns.RcodeSuccess, nil + } + case actionBlock: + { + m := new(dns.Msg). + SetRcode(r, dns.RcodeRefused). + SetEdns0(4096, true) + ede := dns.EDNS0_EDE{InfoCode: dns.ExtendedErrorCodeBlocked} + m.IsEdns0().Option = append(m.IsEdns0().Option, &ede) + w.WriteMsg(m) + RequestBlockCount.WithLabelValues(metrics.WithServer(ctx), zone, metrics.WithView(ctx)).Inc() + return dns.RcodeSuccess, nil + } + case actionAllow: + { + break RulesCheckLoop + } + case actionFilter: + { + m := new(dns.Msg). + SetRcode(r, dns.RcodeSuccess). + SetEdns0(4096, true) + ede := dns.EDNS0_EDE{InfoCode: dns.ExtendedErrorCodeFiltered} + m.IsEdns0().Option = append(m.IsEdns0().Option, &ede) + w.WriteMsg(m) + RequestFilterCount.WithLabelValues(metrics.WithServer(ctx), zone, metrics.WithView(ctx)).Inc() + return dns.RcodeSuccess, nil + } + } + } + + RequestAllowCount.WithLabelValues(metrics.WithServer(ctx), metrics.WithView(ctx)).Inc() + return plugin.NextOrFailure(state.Name(), a.Next, ctx, w, r) +} + +// matchWithPolicies matches the DNS query with a list of ACL polices and returns suitable +// action against the query. 
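+// Policies are evaluated in order: the first policy whose qtype set and IP filter both
+// match the query decides the action; actionNone is returned when no policy matches.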
+func matchWithPolicies(policies []policy, w dns.ResponseWriter, r *dns.Msg) action { + state := request.Request{W: w, Req: r} + + var ip net.IP + if idx := strings.IndexByte(state.IP(), '%'); idx >= 0 { + ip = net.ParseIP(state.IP()[:idx]) + } else { + ip = net.ParseIP(state.IP()) + } + + // if the parsing did not return a proper response then we simply return 'actionBlock' to + // block the query + if ip == nil { + log.Errorf("Blocking request. Unable to parse source address: %v", state.IP()) + return actionBlock + } + qtype := state.QType() + for _, policy := range policies { + // dns.TypeNone matches all query types. + _, matchAll := policy.qtypes[dns.TypeNone] + _, match := policy.qtypes[qtype] + if !matchAll && !match { + continue + } + + _, contained := policy.filter.GetByIP(ip) + if !contained { + continue + } + + // matched. + return policy.action + } + return actionNone +} + +// Name implements the plugin.Handler interface. +func (a ACL) Name() string { + return "acl" +} diff --git a/plugin/acl/acl_test.go b/plugin/acl/acl_test.go new file mode 100644 index 0000000..f867d1f --- /dev/null +++ b/plugin/acl/acl_test.go @@ -0,0 +1,599 @@ +package acl + +import ( + "context" + "testing" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +type testResponseWriter struct { + test.ResponseWriter + Rcode int + Msg *dns.Msg +} + +func (t *testResponseWriter) setRemoteIP(ip string) { + t.RemoteIP = ip +} + +func (t *testResponseWriter) setZone(zone string) { + t.Zone = zone +} + +// WriteMsg implement dns.ResponseWriter interface. +func (t *testResponseWriter) WriteMsg(m *dns.Msg) error { + t.Rcode = m.Rcode + t.Msg = m + return nil +} + +func NewTestControllerWithZones(input string, zones []string) *caddy.Controller { + ctr := caddy.NewTestController("dns", input) + ctr.ServerBlockKeys = append(ctr.ServerBlockKeys, zones...) + return ctr +} + +func TestACLServeDNS(t *testing.T) { + type args struct { + domain string + sourceIP string + qtype uint16 + } + tests := []struct { + name string + config string + zones []string + args args + wantRcode int + wantErr bool + wantExtendedErrorCode uint16 + expectNoResponse bool + }{ + // IPv4 tests. 
+ { + name: "Blacklist 1 BLOCKED", + config: `acl example.org { + block type A net 192.168.0.0/16 + }`, + zones: []string{}, + args: args{ + domain: "www.example.org.", + sourceIP: "192.168.0.2", + qtype: dns.TypeA, + }, + wantRcode: dns.RcodeRefused, + wantExtendedErrorCode: dns.ExtendedErrorCodeBlocked, + }, + { + name: "Blacklist 1 ALLOWED", + config: `acl example.org { + block type A net 192.168.0.0/16 + }`, + zones: []string{}, + args: args{ + domain: "www.example.org.", + sourceIP: "192.167.0.2", + qtype: dns.TypeA, + }, + wantRcode: dns.RcodeSuccess, + }, + { + name: "Blacklist 2 BLOCKED", + config: ` + acl example.org { + block type * net 192.168.0.0/16 + }`, + zones: []string{}, + args: args{ + domain: "www.example.org.", + sourceIP: "192.168.0.2", + qtype: dns.TypeAAAA, + }, + wantRcode: dns.RcodeRefused, + wantExtendedErrorCode: dns.ExtendedErrorCodeBlocked, + }, + { + name: "Blacklist 3 BLOCKED", + config: `acl example.org { + block type A + }`, + zones: []string{}, + args: args{ + domain: "www.example.org.", + sourceIP: "10.1.0.2", + qtype: dns.TypeA, + }, + wantRcode: dns.RcodeRefused, + wantExtendedErrorCode: dns.ExtendedErrorCodeBlocked, + }, + { + name: "Blacklist 3 ALLOWED", + config: `acl example.org { + block type A + }`, + zones: []string{}, + args: args{ + domain: "www.example.org.", + sourceIP: "10.1.0.2", + qtype: dns.TypeAAAA, + }, + wantRcode: dns.RcodeSuccess, + }, + { + name: "Blacklist 4 Single IP BLOCKED", + config: `acl example.org { + block type A net 192.168.1.2 + }`, + zones: []string{}, + args: args{ + domain: "www.example.org.", + sourceIP: "192.168.1.2", + qtype: dns.TypeA, + }, + wantRcode: dns.RcodeRefused, + wantExtendedErrorCode: dns.ExtendedErrorCodeBlocked, + }, + { + name: "Blacklist 4 Single IP ALLOWED", + config: `acl example.org { + block type A net 192.168.1.2 + }`, + zones: []string{}, + args: args{ + domain: "www.example.org.", + sourceIP: "192.168.1.3", + qtype: dns.TypeA, + }, + wantRcode: dns.RcodeSuccess, + }, + { + name: "Filter 1 FILTERED", + config: `acl example.org { + filter type A net 192.168.0.0/16 + }`, + zones: []string{}, + args: args{ + domain: "www.example.org.", + sourceIP: "192.168.0.2", + qtype: dns.TypeA, + }, + wantRcode: dns.RcodeSuccess, + wantExtendedErrorCode: dns.ExtendedErrorCodeFiltered, + }, + { + name: "Filter 1 ALLOWED", + config: `acl example.org { + filter type A net 192.168.0.0/16 + }`, + zones: []string{}, + args: args{ + domain: "www.example.org.", + sourceIP: "192.167.0.2", + qtype: dns.TypeA, + }, + wantRcode: dns.RcodeSuccess, + }, + { + name: "Whitelist 1 ALLOWED", + config: `acl example.org { + allow net 192.168.0.0/16 + block + }`, + zones: []string{}, + args: args{ + domain: "www.example.org.", + sourceIP: "192.168.0.2", + qtype: dns.TypeA, + }, + wantRcode: dns.RcodeSuccess, + }, + { + name: "Whitelist 1 REFUSED", + config: `acl example.org { + allow type * net 192.168.0.0/16 + block + }`, + zones: []string{}, + args: args{ + domain: "www.example.org.", + sourceIP: "10.1.0.2", + qtype: dns.TypeA, + }, + wantRcode: dns.RcodeRefused, + wantExtendedErrorCode: dns.ExtendedErrorCodeBlocked, + }, + { + name: "Drop 1 DROPPED", + config: `acl example.org { + drop net 192.168.0.0/16 + }`, + zones: []string{}, + args: args{ + domain: "www.example.org.", + sourceIP: "192.168.0.2", + qtype: dns.TypeA, + }, + wantRcode: dns.RcodeSuccess, + expectNoResponse: true, + }, + { + name: "Subnet-Order 1 REFUSED", + config: `acl example.org { + block net 192.168.1.0/24 + drop net 192.168.0.0/16 + }`, + zones: 
[]string{}, + args: args{ + domain: "www.example.org.", + sourceIP: "192.168.1.2", + qtype: dns.TypeA, + }, + wantRcode: dns.RcodeRefused, + wantExtendedErrorCode: dns.ExtendedErrorCodeBlocked, + }, + { + name: "Subnet-Order 2 DROPPED", + config: `acl example.org { + drop net 192.168.0.0/16 + block net 192.168.1.0/24 + }`, + zones: []string{}, + args: args{ + domain: "www.example.org.", + sourceIP: "192.168.1.1", + qtype: dns.TypeA, + }, + wantRcode: dns.RcodeSuccess, + expectNoResponse: true, + }, + { + name: "Drop-Type 1 DROPPED", + config: `acl example.org { + drop type A + allow net 192.168.0.0/16 + }`, + zones: []string{}, + args: args{ + domain: "www.example.org.", + sourceIP: "192.168.1.1", + qtype: dns.TypeA, + }, + wantRcode: dns.RcodeSuccess, + expectNoResponse: true, + }, + { + name: "Drop-Type 2 ALLOWED", + config: `acl example.org { + drop type A + allow net 192.168.0.0/16 + }`, + zones: []string{}, + args: args{ + domain: "www.example.org.", + sourceIP: "192.168.1.1", + qtype: dns.TypeAAAA, + }, + wantRcode: dns.RcodeSuccess, + }, + { + name: "Fine-Grained 1 REFUSED", + config: `acl a.example.org { + block type * net 192.168.1.0/24 + }`, + zones: []string{"example.org"}, + args: args{ + domain: "a.example.org.", + sourceIP: "192.168.1.2", + qtype: dns.TypeA, + }, + wantRcode: dns.RcodeRefused, + wantExtendedErrorCode: dns.ExtendedErrorCodeBlocked, + }, + { + name: "Fine-Grained 1 ALLOWED", + config: `acl a.example.org { + block net 192.168.1.0/24 + }`, + zones: []string{"example.org"}, + args: args{ + domain: "www.example.org.", + sourceIP: "192.168.1.2", + qtype: dns.TypeA, + }, + wantRcode: dns.RcodeSuccess, + }, + { + name: "Fine-Grained 2 REFUSED", + config: `acl example.org { + block net 192.168.1.0/24 + }`, + zones: []string{"example.org"}, + args: args{ + domain: "a.example.org.", + sourceIP: "192.168.1.2", + qtype: dns.TypeA, + }, + wantRcode: dns.RcodeRefused, + wantExtendedErrorCode: dns.ExtendedErrorCodeBlocked, + }, + { + name: "Fine-Grained 2 ALLOWED", + config: `acl { + block net 192.168.1.0/24 + }`, + zones: []string{"example.org"}, + args: args{ + domain: "a.example.com.", + sourceIP: "192.168.1.2", + qtype: dns.TypeA, + }, + wantRcode: dns.RcodeSuccess, + }, + { + name: "Fine-Grained 3 REFUSED", + config: `acl a.example.org { + block net 192.168.1.0/24 + } + acl b.example.org { + block type * net 192.168.2.0/24 + }`, + zones: []string{"example.org"}, + args: args{ + domain: "b.example.org.", + sourceIP: "192.168.2.2", + qtype: dns.TypeA, + }, + wantRcode: dns.RcodeRefused, + wantExtendedErrorCode: dns.ExtendedErrorCodeBlocked, + }, + { + name: "Fine-Grained 3 ALLOWED", + config: `acl a.example.org { + block net 192.168.1.0/24 + } + acl b.example.org { + block net 192.168.2.0/24 + }`, + zones: []string{"example.org"}, + args: args{ + domain: "b.example.org.", + sourceIP: "192.168.1.2", + qtype: dns.TypeA, + }, + wantRcode: dns.RcodeSuccess, + }, + // IPv6 tests. 
+ { + name: "Blacklist 1 BLOCKED IPv6", + config: `acl example.org { + block type A net 2001:db8:abcd:0012::0/64 + }`, + zones: []string{}, + args: args{ + domain: "www.example.org.", + sourceIP: "2001:db8:abcd:0012::1230", + qtype: dns.TypeA, + }, + wantRcode: dns.RcodeRefused, + wantExtendedErrorCode: dns.ExtendedErrorCodeBlocked, + }, + { + name: "Blacklist 1 ALLOWED IPv6", + config: `acl example.org { + block type A net 2001:db8:abcd:0012::0/64 + }`, + zones: []string{}, + args: args{ + domain: "www.example.org.", + sourceIP: "2001:db8:abcd:0013::0", + qtype: dns.TypeA, + }, + wantRcode: dns.RcodeSuccess, + }, + { + name: "Blacklist 2 BLOCKED IPv6", + config: `acl example.org { + block type A + }`, + zones: []string{}, + args: args{ + domain: "www.example.org.", + sourceIP: "2001:0db8:85a3:0000:0000:8a2e:0370:7334", + qtype: dns.TypeA, + }, + wantRcode: dns.RcodeRefused, + wantExtendedErrorCode: dns.ExtendedErrorCodeBlocked, + }, + { + name: "Blacklist 3 Single IP BLOCKED IPv6", + config: `acl example.org { + block type A net 2001:0db8:85a3:0000:0000:8a2e:0370:7334 + }`, + zones: []string{}, + args: args{ + domain: "www.example.org.", + sourceIP: "2001:0db8:85a3:0000:0000:8a2e:0370:7334", + qtype: dns.TypeA, + }, + wantRcode: dns.RcodeRefused, + wantExtendedErrorCode: dns.ExtendedErrorCodeBlocked, + }, + { + name: "Blacklist 3 Single IP ALLOWED IPv6", + config: `acl example.org { + block type A net 2001:0db8:85a3:0000:0000:8a2e:0370:7334 + }`, + zones: []string{}, + args: args{ + domain: "www.example.org.", + sourceIP: "2001:0db8:85a3:0000:0000:8a2e:0370:7335", + qtype: dns.TypeA, + }, + wantRcode: dns.RcodeSuccess, + }, + { + name: "Fine-Grained 1 REFUSED IPv6", + config: `acl a.example.org { + block type * net 2001:db8:abcd:0012::0/64 + }`, + zones: []string{"example.org"}, + args: args{ + domain: "a.example.org.", + sourceIP: "2001:db8:abcd:0012:2019::0", + qtype: dns.TypeA, + }, + wantRcode: dns.RcodeRefused, + wantExtendedErrorCode: dns.ExtendedErrorCodeBlocked, + }, + { + name: "Fine-Grained 1 ALLOWED IPv6", + config: `acl a.example.org { + block net 2001:db8:abcd:0012::0/64 + }`, + zones: []string{"example.org"}, + args: args{ + domain: "www.example.org.", + sourceIP: "2001:db8:abcd:0012:2019::0", + qtype: dns.TypeA, + }, + wantRcode: dns.RcodeSuccess, + }, + { + name: "Blacklist Address%ifname", + config: `acl example.org { + block type AAAA net 2001:0db8:85a3:0000:0000:8a2e:0370:7334 + }`, + zones: []string{"eth0"}, + args: args{ + domain: "www.example.org.", + sourceIP: "2001:0db8:85a3:0000:0000:8a2e:0370:7334", + qtype: dns.TypeAAAA, + }, + wantRcode: dns.RcodeRefused, + wantExtendedErrorCode: dns.ExtendedErrorCodeBlocked, + }, + { + name: "Drop 1 DROPPED IPV6", + config: `acl example.org { + drop net 2001:0db8:85a3:0000:0000:8a2e:0370:7334 + }`, + zones: []string{}, + args: args{ + domain: "www.example.org.", + sourceIP: "2001:0db8:85a3:0000:0000:8a2e:0370:7334", + qtype: dns.TypeAAAA, + }, + wantRcode: dns.RcodeSuccess, + expectNoResponse: true, + }, + { + name: "Subnet-Order 1 REFUSED IPv6", + config: `acl example.org { + block net 2001:db8:abcd:0012:8000::/66 + drop net 2001:db8:abcd:0012::0/64 + }`, + zones: []string{}, + args: args{ + domain: "www.example.org.", + sourceIP: "2001:db8:abcd:0012:8000::1", + qtype: dns.TypeAAAA, + }, + wantRcode: dns.RcodeRefused, + wantExtendedErrorCode: dns.ExtendedErrorCodeBlocked, + }, + { + name: "Subnet-Order 2 DROPPED IPv6", + config: `acl example.org { + drop net 2001:db8:abcd:0012::0/64 + block net 2001:db8:abcd:0012:8000::/66 + 
}`, + zones: []string{}, + args: args{ + domain: "www.example.org.", + sourceIP: "2001:db8:abcd:0012:8000::1", + qtype: dns.TypeAAAA, + }, + wantRcode: dns.RcodeSuccess, + expectNoResponse: true, + }, + { + name: "Drop-Type 1 DROPPED IPv6", + config: `acl example.org { + drop type A + allow net 2001:db8:85a3:0000::0/64 + }`, + zones: []string{}, + args: args{ + domain: "www.example.org.", + sourceIP: "2001:0db8:85a3:0000:0000:8a2e:0370:7334", + qtype: dns.TypeA, + }, + wantRcode: dns.RcodeSuccess, + expectNoResponse: true, + }, + { + name: "Drop-Type 2 ALLOWED IPv6", + config: `acl example.org { + drop type A + allow net 2001:db8:85a3:0000::0/64 + }`, + zones: []string{}, + args: args{ + domain: "www.example.org.", + sourceIP: "2001:0db8:85a3:0000:0000:8a2e:0370:7334", + qtype: dns.TypeAAAA, + }, + wantRcode: dns.RcodeSuccess, + }, + } + + ctx := context.Background() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctr := NewTestControllerWithZones(tt.config, tt.zones) + a, err := parse(ctr) + a.Next = test.NextHandler(dns.RcodeSuccess, nil) + if err != nil { + t.Errorf("Error: Cannot parse acl from config: %v", err) + return + } + + w := &testResponseWriter{} + m := new(dns.Msg) + w.setRemoteIP(tt.args.sourceIP) + if len(tt.zones) > 0 { + w.setZone(tt.zones[0]) + } + m.SetQuestion(tt.args.domain, tt.args.qtype) + _, err = a.ServeDNS(ctx, w, m) + if (err != nil) != tt.wantErr { + t.Errorf("Error: acl.ServeDNS() error = %v, wantErr %v", err, tt.wantErr) + return + } + if w.Rcode != tt.wantRcode { + t.Errorf("Error: acl.ServeDNS() Rcode = %v, want %v", w.Rcode, tt.wantRcode) + } + if tt.expectNoResponse && w.Msg != nil { + t.Errorf("Error: acl.ServeDNS() responded to client when not expected") + } + if tt.wantExtendedErrorCode != 0 { + matched := false + for _, opt := range w.Msg.IsEdns0().Option { + if ede, ok := opt.(*dns.EDNS0_EDE); ok { + if ede.InfoCode != tt.wantExtendedErrorCode { + t.Errorf("Error: acl.ServeDNS() Extended DNS Error = %v, want %v", ede.InfoCode, tt.wantExtendedErrorCode) + } + matched = true + } + } + if !matched { + t.Error("Error: acl.ServeDNS() missing Extended DNS Error option") + } + } + }) + } +} diff --git a/plugin/acl/metrics.go b/plugin/acl/metrics.go new file mode 100644 index 0000000..a8d8232 --- /dev/null +++ b/plugin/acl/metrics.go @@ -0,0 +1,39 @@ +package acl + +import ( + "github.com/coredns/coredns/plugin" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +var ( + // RequestBlockCount is the number of DNS requests being blocked. + RequestBlockCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: pluginName, + Name: "blocked_requests_total", + Help: "Counter of DNS requests being blocked.", + }, []string{"server", "zone", "view"}) + // RequestFilterCount is the number of DNS requests being filtered. + RequestFilterCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: pluginName, + Name: "filtered_requests_total", + Help: "Counter of DNS requests being filtered.", + }, []string{"server", "zone", "view"}) + // RequestAllowCount is the number of DNS requests being Allowed. + RequestAllowCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: pluginName, + Name: "allowed_requests_total", + Help: "Counter of DNS requests being allowed.", + }, []string{"server", "view"}) + // RequestDropCount is the number of DNS requests being dropped. 
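+ // Unlike RequestAllowCount (labeled only by server and view), it carries a zone label, as the block and filter counters do.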
+ RequestDropCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: pluginName, + Name: "dropped_requests_total", + Help: "Counter of DNS requests being dropped.", + }, []string{"server", "zone", "view"}) +) diff --git a/plugin/acl/setup.go b/plugin/acl/setup.go new file mode 100644 index 0000000..189acc6 --- /dev/null +++ b/plugin/acl/setup.go @@ -0,0 +1,154 @@ +package acl + +import ( + "net" + "strings" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + + "github.com/infobloxopen/go-trees/iptree" + "github.com/miekg/dns" +) + +const pluginName = "acl" + +func init() { plugin.Register(pluginName, setup) } + +func newDefaultFilter() *iptree.Tree { + defaultFilter := iptree.NewTree() + _, IPv4All, _ := net.ParseCIDR("0.0.0.0/0") + _, IPv6All, _ := net.ParseCIDR("::/0") + defaultFilter.InplaceInsertNet(IPv4All, struct{}{}) + defaultFilter.InplaceInsertNet(IPv6All, struct{}{}) + return defaultFilter +} + +func setup(c *caddy.Controller) error { + a, err := parse(c) + if err != nil { + return plugin.Error(pluginName, err) + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + a.Next = next + return a + }) + + return nil +} + +func parse(c *caddy.Controller) (ACL, error) { + a := ACL{} + for c.Next() { + r := rule{} + args := c.RemainingArgs() + r.zones = plugin.OriginsFromArgsOrServerBlock(args, c.ServerBlockKeys) + + for c.NextBlock() { + p := policy{} + + action := strings.ToLower(c.Val()) + if action == "allow" { + p.action = actionAllow + } else if action == "block" { + p.action = actionBlock + } else if action == "filter" { + p.action = actionFilter + } else if action == "drop" { + p.action = actionDrop + } else { + return a, c.Errf("unexpected token %q; expect 'allow', 'block', 'filter' or 'drop'", c.Val()) + } + + p.qtypes = make(map[uint16]struct{}) + p.filter = iptree.NewTree() + + hasTypeSection := false + hasNetSection := false + + remainingTokens := c.RemainingArgs() + for len(remainingTokens) > 0 { + if !isPreservedIdentifier(remainingTokens[0]) { + return a, c.Errf("unexpected token %q; expect 'type | net'", remainingTokens[0]) + } + section := strings.ToLower(remainingTokens[0]) + + i := 1 + var tokens []string + for ; i < len(remainingTokens) && !isPreservedIdentifier(remainingTokens[i]); i++ { + tokens = append(tokens, remainingTokens[i]) + } + remainingTokens = remainingTokens[i:] + + if len(tokens) == 0 { + return a, c.Errf("no token specified in %q section", section) + } + + switch section { + case "type": + hasTypeSection = true + for _, token := range tokens { + if token == "*" { + p.qtypes[dns.TypeNone] = struct{}{} + break + } + qtype, ok := dns.StringToType[token] + if !ok { + return a, c.Errf("unexpected token %q; expect legal QTYPE", token) + } + p.qtypes[qtype] = struct{}{} + } + case "net": + hasNetSection = true + for _, token := range tokens { + if token == "*" { + p.filter = newDefaultFilter() + break + } + token = normalize(token) + _, source, err := net.ParseCIDR(token) + if err != nil { + return a, c.Errf("illegal CIDR notation %q", token) + } + p.filter.InplaceInsertNet(source, struct{}{}) + } + default: + return a, c.Errf("unexpected token %q; expect 'type | net'", section) + } + } + + // optional `type` section means all record types. + if !hasTypeSection { + p.qtypes[dns.TypeNone] = struct{}{} + } + + // optional `net` means all ip addresses. 
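+ // newDefaultFilter matches every source address (0.0.0.0/0 and ::/0).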
+ if !hasNetSection { + p.filter = newDefaultFilter() + } + + r.policies = append(r.policies, p) + } + a.Rules = append(a.Rules, r) + } + return a, nil +} + +func isPreservedIdentifier(token string) bool { + identifier := strings.ToLower(token) + return identifier == "type" || identifier == "net" +} + +// normalize appends '/32' for any single IPv4 address and '/128' for IPv6. +func normalize(rawNet string) string { + if idx := strings.IndexAny(rawNet, "/"); idx >= 0 { + return rawNet + } + + if idx := strings.IndexAny(rawNet, ":"); idx >= 0 { + return rawNet + "/128" + } + return rawNet + "/32" +} diff --git a/plugin/acl/setup_test.go b/plugin/acl/setup_test.go new file mode 100644 index 0000000..5cd51bb --- /dev/null +++ b/plugin/acl/setup_test.go @@ -0,0 +1,273 @@ +package acl + +import ( + "testing" + + "github.com/coredns/caddy" +) + +func TestSetup(t *testing.T) { + tests := []struct { + name string + config string + wantErr bool + }{ + // IPv4 tests. + { + "Blacklist 1", + `acl { + block type A net 192.168.0.0/16 + }`, + false, + }, + { + "Blacklist 2", + `acl { + block type * net 192.168.0.0/16 + }`, + false, + }, + { + "Blacklist 3", + `acl { + block type A net * + }`, + false, + }, + { + "Blacklist 4", + `acl { + allow type * net 192.168.1.0/24 + block type * net 192.168.0.0/16 + }`, + false, + }, + { + "Filter 1", + `acl { + filter type A net 192.168.0.0/16 + }`, + false, + }, + { + "Whitelist 1", + `acl { + allow type * net 192.168.0.0/16 + block type * net * + }`, + false, + }, + { + "Drop 1", + `acl { + drop type * net 192.168.0.0/16 + }`, + false, + }, + { + "fine-grained 1", + `acl a.example.org { + block type * net 192.168.1.0/24 + }`, + false, + }, + { + "fine-grained 2", + `acl a.example.org { + block type * net 192.168.1.0/24 + } + acl b.example.org { + block type * net 192.168.2.0/24 + }`, + false, + }, + { + "Multiple Networks 1", + `acl example.org { + block type * net 192.168.1.0/24 192.168.3.0/24 + }`, + false, + }, + { + "Multiple Qtypes 1", + `acl example.org { + block type TXT ANY CNAME net 192.168.3.0/24 + }`, + false, + }, + { + "Missing argument 1", + `acl { + block A net 192.168.0.0/16 + }`, + true, + }, + { + "Missing argument 2", + `acl { + block type net 192.168.0.0/16 + }`, + true, + }, + { + "Illegal argument 1", + `acl { + block type ABC net 192.168.0.0/16 + }`, + true, + }, + { + "Illegal argument 2", + `acl { + blck type A net 192.168.0.0/16 + }`, + true, + }, + { + "Illegal argument 3", + `acl { + block type A net 192.168.0/16 + }`, + true, + }, + { + "Illegal argument 4", + `acl { + block type A net 192.168.0.0/33 + }`, + true, + }, + // IPv6 tests. 
+ { + "Blacklist 1 IPv6", + `acl { + block type A net 2001:0db8:85a3:0000:0000:8a2e:0370:7334 + }`, + false, + }, + { + "Blacklist 2 IPv6", + `acl { + block type * net 2001:db8:85a3::8a2e:370:7334 + }`, + false, + }, + { + "Blacklist 3 IPv6", + `acl { + block type A + }`, + false, + }, + { + "Blacklist 4 IPv6", + `acl { + allow net 2001:db8:abcd:0012::0/64 + block net 2001:db8:abcd:0012::0/48 + }`, + false, + }, + { + "Filter 1 IPv6", + `acl { + filter type A net 2001:0db8:85a3:0000:0000:8a2e:0370:7334 + }`, + false, + }, + { + "Whitelist 1 IPv6", + `acl { + allow net 2001:db8:abcd:0012::0/64 + block + }`, + false, + }, + { + "Drop 1 IPv6", + `acl { + drop net 2001:db8:abcd:0012::0/64 + }`, + false, + }, + { + "fine-grained 1 IPv6", + `acl a.example.org { + block net 2001:db8:abcd:0012::0/64 + }`, + false, + }, + { + "fine-grained 2 IPv6", + `acl a.example.org { + block net 2001:db8:abcd:0012::0/64 + } + acl b.example.org { + block net 2001:db8:abcd:0013::0/64 + }`, + false, + }, + { + "Multiple Networks 1 IPv6", + `acl example.org { + block net 2001:db8:abcd:0012::0/64 2001:db8:85a3::8a2e:370:7334/64 + }`, + false, + }, + { + "Illegal argument 1 IPv6", + `acl { + block type A net 2001::85a3::8a2e:370:7334 + }`, + true, + }, + { + "Illegal argument 2 IPv6", + `acl { + block type A net 2001:db8:85a3:::8a2e:370:7334 + }`, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctr := caddy.NewTestController("dns", tt.config) + if err := setup(ctr); (err != nil) != tt.wantErr { + t.Errorf("Error: setup() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestNormalize(t *testing.T) { + type args struct { + rawNet string + } + tests := []struct { + name string + args args + want string + }{ + { + "Network range 1", + args{"10.218.10.8/24"}, + "10.218.10.8/24", + }, + { + "IP address 1", + args{"10.218.10.8"}, + "10.218.10.8/32", + }, + { + "IPv6 address 1", + args{"2001:0db8:85a3:0000:0000:8a2e:0370:7334"}, + "2001:0db8:85a3:0000:0000:8a2e:0370:7334/128", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := normalize(tt.args.rawNet); got != tt.want { + t.Errorf("Error: normalize() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/plugin/any/README.md b/plugin/any/README.md new file mode 100644 index 0000000..25e4ecf --- /dev/null +++ b/plugin/any/README.md @@ -0,0 +1,36 @@ + +# any + +## Name + +*any* - gives a minimal response to ANY queries. + +## Description + +*any* basically blocks ANY queries by responding to them with a short HINFO reply. See [RFC +8482](https://tools.ietf.org/html/rfc8482) for details. + +## Syntax + +~~~ txt +any +~~~ + +## Examples + +~~~ corefile +example.org { + whoami + any +} +~~~ + +A `dig +nocmd ANY example.org +noall +answer` now returns: + +~~~ txt +example.org. 8482 IN HINFO "ANY obsoleted" "See RFC 8482" +~~~ + +## See Also + +[RFC 8482](https://tools.ietf.org/html/rfc8482). diff --git a/plugin/any/any.go b/plugin/any/any.go new file mode 100644 index 0000000..9a05e37 --- /dev/null +++ b/plugin/any/any.go @@ -0,0 +1,32 @@ +package any + +import ( + "context" + + "github.com/coredns/coredns/plugin" + + "github.com/miekg/dns" +) + +// Any is a plugin that returns a HINFO reply to ANY queries. +type Any struct { + Next plugin.Handler +} + +// ServeDNS implements the plugin.Handler interface. 
+func (a Any) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + if r.Question[0].Qtype != dns.TypeANY { + return plugin.NextOrFailure(a.Name(), a.Next, ctx, w, r) + } + + m := new(dns.Msg) + m.SetReply(r) + hdr := dns.RR_Header{Name: r.Question[0].Name, Ttl: 8482, Class: dns.ClassINET, Rrtype: dns.TypeHINFO} + m.Answer = []dns.RR{&dns.HINFO{Hdr: hdr, Cpu: "ANY obsoleted", Os: "See RFC 8482"}} + + w.WriteMsg(m) + return 0, nil +} + +// Name implements the Handler interface. +func (a Any) Name() string { return "any" } diff --git a/plugin/any/any_test.go b/plugin/any/any_test.go new file mode 100644 index 0000000..85df7d6 --- /dev/null +++ b/plugin/any/any_test.go @@ -0,0 +1,28 @@ +package any + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestAny(t *testing.T) { + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeANY) + a := &Any{} + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err := a.ServeDNS(context.TODO(), rec, req) + + if err != nil { + t.Errorf("Expected no error, but got %q", err) + } + + if rec.Msg.Answer[0].(*dns.HINFO).Cpu != "ANY obsoleted" { + t.Errorf("Expected HINFO, but got %q", rec.Msg.Answer[0].(*dns.HINFO).Cpu) + } +} diff --git a/plugin/any/setup.go b/plugin/any/setup.go new file mode 100644 index 0000000..5c8a93b --- /dev/null +++ b/plugin/any/setup.go @@ -0,0 +1,20 @@ +package any + +import ( + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" +) + +func init() { plugin.Register("any", setup) } + +func setup(c *caddy.Controller) error { + a := Any{} + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + a.Next = next + return a + }) + + return nil +} diff --git a/plugin/auto/README.md b/plugin/auto/README.md new file mode 100644 index 0000000..661e419 --- /dev/null +++ b/plugin/auto/README.md @@ -0,0 +1,82 @@ +# auto + +## Name + +*auto* - enables serving zone data from an RFC 1035-style master file, which is automatically picked up from disk. + +## Description + +The *auto* plugin is used for an "old-style" DNS server. It serves from a preloaded file that exists +on disk. If the zone file contains signatures (i.e. is signed, i.e. using DNSSEC) correct DNSSEC answers +are returned. Only NSEC is supported! If you use this setup *you* are responsible for re-signing the +zonefile. New or changed zones are automatically picked up from disk only when SOA's serial changes. If the zones are not updated via a zone transfer, the serial must be manually changed. + +## Syntax + +~~~ +auto [ZONES...] { + directory DIR [REGEXP ORIGIN_TEMPLATE] + reload DURATION +} +~~~ + +**ZONES** zones it should be authoritative for. If empty, the zones from the configuration block +are used. + +* `directory` loads zones from the specified **DIR**. If a file name matches **REGEXP** it will be + used to extract the origin. **ORIGIN_TEMPLATE** will be used as a template for the origin. Strings + like `{<number>}` are replaced with the respective matches in the file name, e.g. `{1}` is the + first match, `{2}` is the second. The default is: `db\.(.*) {1}` i.e. from a file with the + name `db.example.com`, the extracted origin will be `example.com`. +* `reload` interval to perform reloads of zones if SOA version changes and zonefiles. It specifies how often CoreDNS should scan the directory to watch for file removal and addition. 
Default is one minute. + Value of `0` means to not scan for changes and reload. eg. `30s` checks zonefile every 30 seconds + and reloads zone when serial changes. + +For enabling zone transfers look at the *transfer* plugin. + +All directives from the *file* plugin are supported. Note that *auto* will load all zones found, +even though the directive might only receive queries for a specific zone. I.e: + +~~~ corefile +. { + auto example.org { + directory /etc/coredns/zones + } +} +~~~ +Will happily pick up a zone for `example.COM`, except it will never be queried, because the *auto* +directive only is authoritative for `example.ORG`. + +## Examples + +Load `org` domains from `/etc/coredns/zones/org` and allow transfers to the internet, but send +notifies to 10.240.1.1 + +~~~ corefile +org { + auto { + directory /etc/coredns/zones/org + } + transfer { + to * + to 10.240.1.1 + } +} +~~~ + +Load `org` domains from `/etc/coredns/zones/org` and looks for file names as `www.db.example.org`, +where `example.org` is the origin. Scan every 45 seconds. + +~~~ corefile +org { + auto { + directory /etc/coredns/zones/org www\.db\.(.*) {1} + reload 45s + } +} +~~~ + +## Also + +Use the *root* plugin to help you specify the location of the zone files. See the *transfer* plugin +to enable outgoing zone transfers. diff --git a/plugin/auto/auto.go b/plugin/auto/auto.go new file mode 100644 index 0000000..581004b --- /dev/null +++ b/plugin/auto/auto.go @@ -0,0 +1,100 @@ +// Package auto implements an on-the-fly loading file backend. +package auto + +import ( + "context" + "regexp" + "time" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/file" + "github.com/coredns/coredns/plugin/metrics" + "github.com/coredns/coredns/plugin/pkg/upstream" + "github.com/coredns/coredns/plugin/transfer" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +type ( + // Auto holds the zones and the loader configuration for automatically loading zones. + Auto struct { + Next plugin.Handler + *Zones + + metrics *metrics.Metrics + transfer *transfer.Transfer + loader + } + + loader struct { + directory string + template string + re *regexp.Regexp + + ReloadInterval time.Duration + upstream *upstream.Upstream // Upstream for looking up names during the resolution process. + } +) + +// ServeDNS implements the plugin.Handler interface. +func (a Auto) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + state := request.Request{W: w, Req: r} + qname := state.Name() + + // Precheck with the origins, i.e. are we allowed to look here? + zone := plugin.Zones(a.Zones.Origins()).Matches(qname) + if zone == "" { + return plugin.NextOrFailure(a.Name(), a.Next, ctx, w, r) + } + + // Now the real zone. + zone = plugin.Zones(a.Zones.Names()).Matches(qname) + if zone == "" { + return plugin.NextOrFailure(a.Name(), a.Next, ctx, w, r) + } + + a.Zones.RLock() + z, ok := a.Zones.Z[zone] + a.Zones.RUnlock() + + if !ok || z == nil { + return dns.RcodeServerFailure, nil + } + + // If transfer is not loaded, we'll see these, answer with refused (no transfer allowed). 
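+ // (With the *transfer* plugin configured, zone transfer queries are answered before they reach this plugin.)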
+ if state.QType() == dns.TypeAXFR || state.QType() == dns.TypeIXFR { + return dns.RcodeRefused, nil + } + + answer, ns, extra, result := z.Lookup(ctx, state, qname) + + m := new(dns.Msg) + m.SetReply(r) + m.Authoritative = true + m.Answer, m.Ns, m.Extra = answer, ns, extra + + switch result { + case file.Success: + case file.NoData: + case file.NameError: + m.Rcode = dns.RcodeNameError + case file.Delegation: + m.Authoritative = false + case file.ServerFailure: + // If the result is SERVFAIL and the answer is non-empty, then the SERVFAIL came from an + // external CNAME lookup and the answer contains the CNAME with no target record. We should + // write the CNAME record to the client instead of sending an empty SERVFAIL response. + if len(m.Answer) == 0 { + return dns.RcodeServerFailure, nil + } + // The rcode in the response should be the rcode received from the target lookup. RFC 6604 section 3 + m.Rcode = dns.RcodeServerFailure + } + + w.WriteMsg(m) + return dns.RcodeSuccess, nil +} + +// Name implements the Handler interface. +func (a Auto) Name() string { return "auto" } diff --git a/plugin/auto/log_test.go b/plugin/auto/log_test.go new file mode 100644 index 0000000..6047eeb --- /dev/null +++ b/plugin/auto/log_test.go @@ -0,0 +1,5 @@ +package auto + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/auto/regexp.go b/plugin/auto/regexp.go new file mode 100644 index 0000000..fa424ec --- /dev/null +++ b/plugin/auto/regexp.go @@ -0,0 +1,20 @@ +package auto + +// rewriteToExpand rewrites our template string to one that we can give to regexp.ExpandString. This basically +// involves prefixing any '{' with a '$'. +func rewriteToExpand(s string) string { + // Pretty dumb at the moment, every { will get a $ prefixed. + // Also wasteful as we build the string with +=. This is OKish + // as we do this during config parsing. 
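+ // For example, "{1}" becomes "${1}" (see regexp_test.go).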
+ + copy := "" + + for _, c := range s { + if c == '{' { + copy += "$" + } + copy += string(c) + } + + return copy +} diff --git a/plugin/auto/regexp_test.go b/plugin/auto/regexp_test.go new file mode 100644 index 0000000..17c35eb --- /dev/null +++ b/plugin/auto/regexp_test.go @@ -0,0 +1,20 @@ +package auto + +import "testing" + +func TestRewriteToExpand(t *testing.T) { + tests := []struct { + in string + expected string + }{ + {in: "", expected: ""}, + {in: "{1}", expected: "${1}"}, + {in: "{1", expected: "${1"}, + } + for i, tc := range tests { + got := rewriteToExpand(tc.in) + if got != tc.expected { + t.Errorf("Test %d: Expected error %v, but got %v", i, tc.expected, got) + } + } +} diff --git a/plugin/auto/setup.go b/plugin/auto/setup.go new file mode 100644 index 0000000..bd94797 --- /dev/null +++ b/plugin/auto/setup.go @@ -0,0 +1,172 @@ +package auto + +import ( + "errors" + "os" + "path/filepath" + "regexp" + "time" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/metrics" + clog "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/plugin/pkg/upstream" + "github.com/coredns/coredns/plugin/transfer" +) + +var log = clog.NewWithPlugin("auto") + +func init() { plugin.Register("auto", setup) } + +func setup(c *caddy.Controller) error { + a, err := autoParse(c) + if err != nil { + return plugin.Error("auto", err) + } + + c.OnStartup(func() error { + m := dnsserver.GetConfig(c).Handler("prometheus") + if m != nil { + (&a).metrics = m.(*metrics.Metrics) + } + t := dnsserver.GetConfig(c).Handler("transfer") + if t != nil { + (&a).transfer = t.(*transfer.Transfer) + } + return nil + }) + + walkChan := make(chan bool) + + c.OnStartup(func() error { + err := a.Walk() + if err != nil { + return err + } + if err := a.Notify(); err != nil { + log.Warning(err) + } + if a.loader.ReloadInterval == 0 { + return nil + } + go func() { + ticker := time.NewTicker(a.loader.ReloadInterval) + defer ticker.Stop() + for { + select { + case <-walkChan: + return + case <-ticker.C: + a.Walk() + if err := a.Notify(); err != nil { + log.Warning(err) + } + } + } + }() + return nil + }) + + c.OnShutdown(func() error { + close(walkChan) + return nil + }) + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + a.Next = next + return a + }) + + return nil +} + +func autoParse(c *caddy.Controller) (Auto, error) { + nilInterval := -1 * time.Second + var a = Auto{ + loader: loader{ + template: "${1}", + re: regexp.MustCompile(`db\.(.*)`), + ReloadInterval: nilInterval, + }, + Zones: &Zones{}, + } + + config := dnsserver.GetConfig(c) + + for c.Next() { + // auto [ZONES...] 
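+ // e.g. `auto example.org { ... }`; when ZONES is empty the origins fall back to the server block keys.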
+ args := c.RemainingArgs() + a.Zones.origins = plugin.OriginsFromArgsOrServerBlock(args, c.ServerBlockKeys) + a.loader.upstream = upstream.New() + + for c.NextBlock() { + switch c.Val() { + case "directory": // directory DIR [REGEXP TEMPLATE] + if !c.NextArg() { + return a, c.ArgErr() + } + a.loader.directory = c.Val() + if !filepath.IsAbs(a.loader.directory) && config.Root != "" { + a.loader.directory = filepath.Join(config.Root, a.loader.directory) + } + _, err := os.Stat(a.loader.directory) + if err != nil { + if os.IsNotExist(err) { + log.Warningf("Directory does not exist: %s", a.loader.directory) + } else { + return a, c.Errf("Unable to access root path '%s': %v", a.loader.directory, err) + } + } + + // regexp template + if c.NextArg() { + a.loader.re, err = regexp.Compile(c.Val()) + if err != nil { + return a, err + } + if a.loader.re.NumSubexp() == 0 { + return a, c.Errf("Need at least one sub expression") + } + + if !c.NextArg() { + return a, c.ArgErr() + } + a.loader.template = rewriteToExpand(c.Val()) + } + + if c.NextArg() { + return Auto{}, c.ArgErr() + } + + case "reload": + t := c.RemainingArgs() + if len(t) < 1 { + return a, errors.New("reload duration value is expected") + } + d, err := time.ParseDuration(t[0]) + if d < 0 { + err = errors.New("invalid duration") + } + if err != nil { + return a, plugin.Error("file", err) + } + a.loader.ReloadInterval = d + + case "upstream": + // remove soon + c.RemainingArgs() // eat remaining args + + default: + return Auto{}, c.Errf("unknown property '%s'", c.Val()) + } + } + } + + if a.loader.ReloadInterval == nilInterval { + a.loader.ReloadInterval = 60 * time.Second + } + + return a, nil +} diff --git a/plugin/auto/setup_test.go b/plugin/auto/setup_test.go new file mode 100644 index 0000000..4fada6f --- /dev/null +++ b/plugin/auto/setup_test.go @@ -0,0 +1,177 @@ +package auto + +import ( + "testing" + "time" + + "github.com/coredns/caddy" +) + +func TestAutoParse(t *testing.T) { + tests := []struct { + inputFileRules string + shouldErr bool + expectedDirectory string + expectedTempl string + expectedRe string + expectedReloadInterval time.Duration + }{ + { + `auto example.org { + directory /tmp + }`, + false, "/tmp", "${1}", `db\.(.*)`, 60 * time.Second, + }, + { + `auto 10.0.0.0/24 { + directory /tmp + }`, + false, "/tmp", "${1}", `db\.(.*)`, 60 * time.Second, + }, + { + `auto { + directory /tmp + reload 0 + }`, + false, "/tmp", "${1}", `db\.(.*)`, 0 * time.Second, + }, + { + `auto { + directory /tmp (.*) bliep + }`, + false, "/tmp", "bliep", `(.*)`, 60 * time.Second, + }, + { + `auto { + directory /tmp (.*) bliep + reload 10s + }`, + false, "/tmp", "bliep", `(.*)`, 10 * time.Second, + }, + // errors + // NO_RELOAD has been deprecated. + { + `auto { + directory /tmp + no_reload + }`, + true, "/tmp", "${1}", `db\.(.*)`, 0 * time.Second, + }, + // TIMEOUT has been deprecated. + { + `auto { + directory /tmp (.*) bliep 10 + }`, + true, "/tmp", "bliep", `(.*)`, 10 * time.Second, + }, + // TRANSFER has been deprecated. + { + `auto { + directory /tmp (.*) bliep 10 + transfer to 127.0.0.1 + }`, + true, "/tmp", "bliep", `(.*)`, 10 * time.Second, + }, + // no template specified. + { + `auto { + directory /tmp (.*) + }`, + true, "/tmp", "", `(.*)`, 60 * time.Second, + }, + // no directory specified. + { + `auto example.org { + directory + }`, + true, "", "${1}", `db\.(.*)`, 60 * time.Second, + }, + // illegal REGEXP. 
+ { + `auto example.org { + directory /tmp * {1} + }`, + true, "/tmp", "${1}", ``, 60 * time.Second, + }, + // unexpected argument. + { + `auto example.org { + directory /tmp (.*) {1} aa + }`, + true, "/tmp", "${1}", ``, 60 * time.Second, + }, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.inputFileRules) + a, err := autoParse(c) + + if err == nil && test.shouldErr { + t.Fatalf("Test %d expected errors, but got no error", i) + } else if err != nil && !test.shouldErr { + t.Fatalf("Test %d expected no errors, but got '%v'", i, err) + } else if !test.shouldErr { + if a.loader.directory != test.expectedDirectory { + t.Fatalf("Test %d expected %v, got %v", i, test.expectedDirectory, a.loader.directory) + } + if a.loader.template != test.expectedTempl { + t.Fatalf("Test %d expected %v, got %v", i, test.expectedTempl, a.loader.template) + } + if a.loader.re.String() != test.expectedRe { + t.Fatalf("Test %d expected %v, got %v", i, test.expectedRe, a.loader.re) + } + if a.loader.ReloadInterval != test.expectedReloadInterval { + t.Fatalf("Test %d expected %v, got %v", i, test.expectedReloadInterval, a.loader.ReloadInterval) + } + } + } +} + +func TestSetupReload(t *testing.T) { + tests := []struct { + name string + config string + wantErr bool + }{ + { + name: "reload valid", + config: `auto { + directory . + reload 5s + }`, + wantErr: false, + }, + { + name: "reload disable", + config: `auto { + directory . + reload 0 + }`, + wantErr: false, + }, + { + name: "reload invalid", + config: `auto { + directory . + reload -1s + }`, + wantErr: true, + }, + { + name: "reload invalid", + config: `auto { + directory . + reload + }`, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctr := caddy.NewTestController("dns", tt.config) + if err := setup(ctr); (err != nil) != tt.wantErr { + t.Errorf("Error: setup() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/plugin/auto/walk.go b/plugin/auto/walk.go new file mode 100644 index 0000000..38f6375 --- /dev/null +++ b/plugin/auto/walk.go @@ -0,0 +1,107 @@ +package auto + +import ( + "os" + "path/filepath" + "regexp" + + "github.com/coredns/coredns/plugin/file" + + "github.com/miekg/dns" +) + +// Walk will recursively walk of the file under l.directory and adds the one that match l.re. +func (a Auto) Walk() error { + // TODO(miek): should add something so that we don't stomp on each other. + + toDelete := make(map[string]bool) + for _, n := range a.Zones.Names() { + toDelete[n] = true + } + + filepath.Walk(a.loader.directory, func(path string, info os.FileInfo, e error) error { + if e != nil { + log.Warningf("error reading %v: %v", path, e) + } + if info == nil || info.IsDir() { + return nil + } + + match, origin := matches(a.loader.re, info.Name(), a.loader.template) + if !match { + return nil + } + + if z, ok := a.Zones.Z[origin]; ok { + // we already have this zone + toDelete[origin] = false + z.SetFile(path) + return nil + } + + reader, err := os.Open(filepath.Clean(path)) + if err != nil { + log.Warningf("Opening %s failed: %s", path, err) + return nil + } + defer reader.Close() + + // Serial for loading a zone is 0, because it is a new zone. 
+ zo, err := file.Parse(reader, origin, path, 0) + if err != nil { + log.Warningf("Parse zone `%s': %v", origin, err) + return nil + } + + zo.ReloadInterval = a.loader.ReloadInterval + zo.Upstream = a.loader.upstream + + a.Zones.Add(zo, origin, a.transfer) + + if a.metrics != nil { + a.metrics.AddZone(origin) + } + + log.Infof("Inserting zone `%s' from: %s", origin, path) + + toDelete[origin] = false + + return nil + }) + + for origin, ok := range toDelete { + if !ok { + continue + } + + if a.metrics != nil { + a.metrics.RemoveZone(origin) + } + + a.Zones.Remove(origin) + + log.Infof("Deleting zone `%s'", origin) + } + + return nil +} + +// matches re to filename, if it is a match, the subexpression will be used to expand +// template to an origin. When match is true that origin is returned. Origin is fully qualified. +func matches(re *regexp.Regexp, filename, template string) (match bool, origin string) { + base := filepath.Base(filename) + + matches := re.FindStringSubmatchIndex(base) + if matches == nil { + return false, "" + } + + by := re.ExpandString(nil, template, base, matches) + if by == nil { + return false, "" + } + + origin = dns.Fqdn(string(by)) + + return true, origin +} diff --git a/plugin/auto/walk_test.go b/plugin/auto/walk_test.go new file mode 100644 index 0000000..062c992 --- /dev/null +++ b/plugin/auto/walk_test.go @@ -0,0 +1,81 @@ +package auto + +import ( + "os" + "path/filepath" + "regexp" + "testing" +) + +var dbFiles = []string{"db.example.org", "aa.example.org"} + +const zoneContent = `; testzone +@ IN SOA sns.dns.icann.org. noc.dns.icann.org. 2016082534 7200 3600 1209600 3600 + NS a.iana-servers.net. + NS b.iana-servers.net. + +www IN A 127.0.0.1 +` + +func TestWalk(t *testing.T) { + tempdir, err := createFiles(t) + if err != nil { + t.Fatal(err) + } + + ldr := loader{ + directory: tempdir, + re: regexp.MustCompile(`db\.(.*)`), + template: `${1}`, + } + + a := Auto{ + loader: ldr, + Zones: &Zones{}, + } + + a.Walk() + + // db.example.org and db.example.com should be here (created in createFiles) + for _, name := range []string{"example.com.", "example.org."} { + if _, ok := a.Zones.Z[name]; !ok { + t.Errorf("%s should have been added", name) + } + } +} + +func TestWalkNonExistent(t *testing.T) { + nonExistingDir := "highly_unlikely_to_exist_dir" + + ldr := loader{ + directory: nonExistingDir, + re: regexp.MustCompile(`db\.(.*)`), + template: `${1}`, + } + + a := Auto{ + loader: ldr, + Zones: &Zones{}, + } + + a.Walk() +} + +func createFiles(t *testing.T) (string, error) { + dir := t.TempDir() + + for _, name := range dbFiles { + if err := os.WriteFile(filepath.Join(dir, name), []byte(zoneContent), 0644); err != nil { + return dir, err + } + } + // symlinks + if err := os.Symlink(filepath.Join(dir, "db.example.org"), filepath.Join(dir, "db.example.com")); err != nil { + return dir, err + } + if err := os.Symlink(filepath.Join(dir, "db.example.org"), filepath.Join(dir, "aa.example.com")); err != nil { + return dir, err + } + + return dir, nil +} diff --git a/plugin/auto/watcher_test.go b/plugin/auto/watcher_test.go new file mode 100644 index 0000000..9a256f4 --- /dev/null +++ b/plugin/auto/watcher_test.go @@ -0,0 +1,92 @@ +package auto + +import ( + "os" + "path/filepath" + "regexp" + "testing" +) + +func TestWatcher(t *testing.T) { + tempdir, err := createFiles(t) + if err != nil { + t.Fatal(err) + } + + ldr := loader{ + directory: tempdir, + re: regexp.MustCompile(`db\.(.*)`), + template: `${1}`, + } + + a := Auto{ + loader: ldr, + Zones: &Zones{}, + } + + 
a.Walk() + + // example.org and example.com should exist, we have 3 apex rrs and 1 "real" record. All() returns the non-apex ones. + if x := len(a.Zones.Z["example.org."].All()); x != 1 { + t.Fatalf("Expected 1 RRs, got %d", x) + } + if x := len(a.Zones.Z["example.com."].All()); x != 1 { + t.Fatalf("Expected 1 RRs, got %d", x) + } + + // Now remove one file, rescan and see if it's gone. + if err := os.Remove(filepath.Join(tempdir, "db.example.com")); err != nil { + t.Fatal(err) + } + + a.Walk() + + if _, ok := a.Zones.Z["example.com."]; ok { + t.Errorf("Expected %q to be gone.", "example.com.") + } + if _, ok := a.Zones.Z["example.org."]; !ok { + t.Errorf("Expected %q to still be there.", "example.org.") + } +} + +func TestSymlinks(t *testing.T) { + tempdir, err := createFiles(t) + if err != nil { + t.Fatal(err) + } + + ldr := loader{ + directory: tempdir, + re: regexp.MustCompile(`db\.(.*)`), + template: `${1}`, + } + + a := Auto{ + loader: ldr, + Zones: &Zones{}, + } + + a.Walk() + + // Now create a duplicate file in a subdirectory and repoint the symlink + if err := os.Remove(filepath.Join(tempdir, "db.example.com")); err != nil { + t.Fatal(err) + } + dataDir := filepath.Join(tempdir, "..data") + if err = os.Mkdir(dataDir, 0755); err != nil { + t.Fatal(err) + } + newFile := filepath.Join(dataDir, "db.example.com") + if err = os.Symlink(filepath.Join(tempdir, "db.example.org"), newFile); err != nil { + t.Fatal(err) + } + + a.Walk() + + if storedZone, ok := a.Zones.Z["example.com."]; ok { + storedFile := storedZone.File() + if storedFile != newFile { + t.Errorf("Expected %q to reflect new path %q", storedFile, newFile) + } + } +} diff --git a/plugin/auto/xfr.go b/plugin/auto/xfr.go new file mode 100644 index 0000000..e6a9ba5 --- /dev/null +++ b/plugin/auto/xfr.go @@ -0,0 +1,31 @@ +package auto + +import ( + "github.com/coredns/coredns/plugin/transfer" + + "github.com/miekg/dns" +) + +// Transfer implements the transfer.Transfer interface. +func (a Auto) Transfer(zone string, serial uint32) (<-chan []dns.RR, error) { + a.Zones.RLock() + z, ok := a.Zones.Z[zone] + a.Zones.RUnlock() + + if !ok || z == nil { + return nil, transfer.ErrNotAuthoritative + } + return z.Transfer(serial) +} + +// Notify sends notifies for all zones with secondaries configured with the transfer plugin +func (a Auto) Notify() error { + var err error + for _, origin := range a.Zones.Names() { + e := a.transfer.Notify(origin) + if e != nil { + err = e + } + } + return err +} diff --git a/plugin/auto/zone.go b/plugin/auto/zone.go new file mode 100644 index 0000000..bb81186 --- /dev/null +++ b/plugin/auto/zone.go @@ -0,0 +1,77 @@ +// Package auto implements a on-the-fly loading file backend. +package auto + +import ( + "sync" + + "github.com/coredns/coredns/plugin/file" + "github.com/coredns/coredns/plugin/transfer" +) + +// Zones maps zone names to a *Zone. This keeps track of what zones we have loaded at +// any one time. +type Zones struct { + Z map[string]*file.Zone // A map mapping zone (origin) to the Zone's data. + names []string // All the keys from the map Z as a string slice. + + origins []string // Any origins from the server block. + + sync.RWMutex +} + +// Names returns the names from z. +func (z *Zones) Names() []string { + z.RLock() + n := z.names + z.RUnlock() + return n +} + +// Origins returns the origins from z. +func (z *Zones) Origins() []string { + // doesn't need locking, because there aren't multiple Go routines accessing it. 
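+ // origins is only written during config parsing (autoParse) and is read-only afterwards.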
+ return z.origins +} + +// Zones returns a zone with origin name from z, nil when not found. +func (z *Zones) Zones(name string) *file.Zone { + z.RLock() + zo := z.Z[name] + z.RUnlock() + return zo +} + +// Add adds a new zone into z. If z.ReloadInterval is not zero, the +// reload goroutine is started. +func (z *Zones) Add(zo *file.Zone, name string, t *transfer.Transfer) { + z.Lock() + + if z.Z == nil { + z.Z = make(map[string]*file.Zone) + } + + z.Z[name] = zo + z.names = append(z.names, name) + zo.Reload(t) + + z.Unlock() +} + +// Remove removes the zone named name from z. It also stops the zone's reload goroutine. +func (z *Zones) Remove(name string) { + z.Lock() + + if zo, ok := z.Z[name]; ok { + zo.OnShutdown() + } + + delete(z.Z, name) + + // TODO(miek): just regenerate Names (might be bad if you have a lot of zones...) + z.names = []string{} + for n := range z.Z { + z.names = append(z.names, n) + } + + z.Unlock() +} diff --git a/plugin/autopath/README.md b/plugin/autopath/README.md new file mode 100644 index 0000000..eedbf5e --- /dev/null +++ b/plugin/autopath/README.md @@ -0,0 +1,68 @@ +# autopath + +## Name + +*autopath* - allows for server-side search path completion. + +## Description + +If the *autopath* plugin sees a query that matches the first element of the configured search path, it will +follow the chain of search path elements and return the first reply that is not NXDOMAIN. On any +failures, the original reply is returned. Because *autopath* returns a reply for a name that wasn't +the original question, it will add a CNAME that points from the original name (with the search path +element in it) to the name of this answer. + +**Note**: There are several known issues, see the "Bugs" section below. + +## Syntax + +~~~ +autopath [ZONE...] RESOLV-CONF +~~~ + +* **ZONES** zones *autopath* should be authoritative for. +* **RESOLV-CONF** points to a `resolv.conf` like file or uses a special syntax to point to another + plugin. For instance `@kubernetes`, will call out to the kubernetes plugin (for each + query) to retrieve the search list it should use. + +If a plugin implements the `AutoPather` interface then it can be used by *autopath*. + +## Metrics + +If monitoring is enabled (via the *prometheus* plugin) then the following metric is exported: + +* `coredns_autopath_success_total{server}` - counter of successfully autopath-ed queries. + +The `server` label is explained in the *metrics* plugin documentation. + +## Examples + +~~~ +autopath my-resolv.conf +~~~ + +Use `my-resolv.conf` as the file to get the search path from. This file only needs to have one line: +`search domain1 domain2 ...` + +~~~ +autopath @kubernetes +~~~ + +Use the search path dynamically retrieved from the *kubernetes* plugin. + +## Bugs + +In Kubernetes, *autopath* can derive the wrong namespace of a client Pod (and therefore wrong search +path) in the following case. To properly build the search path of a client *autopath* needs to know +the namespace of the a Pod making a DNS request. To do this, it relies on the *kubernetes* plugin's +Pod cache to resolve the client's IP address to a Pod. The Pod cache is maintained by an API watch +on Pods. When Pod IP assignments change, the Kubernetes API notifies CoreDNS via the API watch. +However, that notification is not instantaneous. 
In the case that a Pod is deleted, and its IP is +immediately provisioned to a Pod in another namespace, and that new Pod makes a DNS lookup *before* +the API watch can notify CoreDNS of the change, *autopath* will resolve the IP to the previous Pod's +namespace. + +In Kubernetes, *autopath* is not compatible with Pods running from Windows nodes. + +If the server-side search ultimately results in a negative answer (e.g. `NXDOMAIN`), then the client +will fruitlessly search all paths manually, thus negating the *autopath* optimization. diff --git a/plugin/autopath/autopath.go b/plugin/autopath/autopath.go new file mode 100644 index 0000000..f6b3488 --- /dev/null +++ b/plugin/autopath/autopath.go @@ -0,0 +1,157 @@ +/* +Package autopath implements autopathing. This is a hack; it shortcuts the +client's search path resolution by performing these lookups on the server... + +The server has a copy (via AutoPathFunc) of the client's search path and on +receiving a query it first establishes if the suffix matches the FIRST configured +element. If no match can be found the query will be forwarded up the plugin +chain without interference (if, and only if, 'fallthrough' has been set). + +If the query is deemed to fall in the search path the server will perform the +queries with each element of the search path appended in sequence until a +non-NXDOMAIN answer has been found. That reply will then be returned to the +client - with some CNAME hackery to let the client accept the reply. + +If all queries return NXDOMAIN we return the original as-is and let the client +continue searching. The client will go to the next element in the search path, +but we won’t do any more autopathing. It means that in the failure case, you do +more work, since the server looks it up, then the client still needs to go +through the search path. + +It is assumed that the search path ordering is identical between server and client. + +Plugins implementing autopath must have a function called `AutoPath` of type +autopath.Func. Note that the search path must end with the empty string. + +For example: + + func (m Plugins) AutoPath(state request.Request) []string { + return []string{"first", "second", "last", ""} + } +*/ +package autopath + +import ( + "context" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/metrics" + "github.com/coredns/coredns/plugin/pkg/dnsutil" + "github.com/coredns/coredns/plugin/pkg/nonwriter" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// Func defines the function a plugin should implement to return a search +// path to the autopath plugin. The last element of the slice must be the empty string. +// If Func returns a nil slice, no autopathing will be done. +type Func func(request.Request) []string + +// AutoPather defines the interface that a plugin should implement in order to be +// used by AutoPath. +type AutoPather interface { + AutoPath(request.Request) []string +} + +// AutoPath performs autopath: server-side search path completion. +type AutoPath struct { + Next plugin.Handler + Zones []string + + // Search always includes "" as the last element, so we try the base query without any search paths added as well. + search []string + searchFunc Func +} + +// ServeDNS implements the plugin.Handler interface. 
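+// A query is only autopathed when its name falls under one of the configured Zones and under the first search path element (equal to it or a subdomain of it); everything else is handed to the next plugin untouched.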
+func (a *AutoPath) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + state := request.Request{W: w, Req: r} + + zone := plugin.Zones(a.Zones).Matches(state.Name()) + if zone == "" { + return plugin.NextOrFailure(a.Name(), a.Next, ctx, w, r) + } + + // Check if autopath should be done, searchFunc takes precedence over the local configured search path. + var err error + searchpath := a.search + + if a.searchFunc != nil { + searchpath = a.searchFunc(state) + } + + if len(searchpath) == 0 { + return plugin.NextOrFailure(a.Name(), a.Next, ctx, w, r) + } + + if !firstInSearchPath(state.Name(), searchpath) { + return plugin.NextOrFailure(a.Name(), a.Next, ctx, w, r) + } + + origQName := state.QName() + + // Establish base name of the query. I.e what was originally asked. + base, err := dnsutil.TrimZone(state.QName(), searchpath[0]) + if err != nil { + return dns.RcodeServerFailure, err + } + + firstReply := new(dns.Msg) + firstRcode := 0 + var firstErr error + + ar := r.Copy() + // Walk the search path and see if we can get a non-nxdomain - if they all fail we return the first + // query we've done and return that as-is. This means the client will do the search path walk again... + for i, s := range searchpath { + newQName := base + "." + s + ar.Question[0].Name = newQName + nw := nonwriter.New(w) + + rcode, err := plugin.NextOrFailure(a.Name(), a.Next, ctx, nw, ar) + if err != nil { + // Return now - not sure if this is the best. We should also check if the write has happened. + return rcode, err + } + if i == 0 { + firstReply = nw.Msg + firstRcode = rcode + firstErr = err + } + + if !plugin.ClientWrite(rcode) { + continue + } + + if nw.Msg.Rcode == dns.RcodeNameError { + continue + } + + msg := nw.Msg + cnamer(msg, origQName) + + // Write whatever non-nxdomain answer we've found. + w.WriteMsg(msg) + autoPathCount.WithLabelValues(metrics.WithServer(ctx)).Add(1) + return rcode, err + } + if plugin.ClientWrite(firstRcode) { + w.WriteMsg(firstReply) + } + return firstRcode, firstErr +} + +// Name implements the Handler interface. +func (a *AutoPath) Name() string { return "autopath" } + +// firstInSearchPath checks if name is equal to are a sibling of the first element in the search path. +func firstInSearchPath(name string, searchpath []string) bool { + if name == searchpath[0] { + return true + } + if dns.IsSubDomain(searchpath[0], name) { + return true + } + return false +} diff --git a/plugin/autopath/autopath_test.go b/plugin/autopath/autopath_test.go new file mode 100644 index 0000000..5c4e554 --- /dev/null +++ b/plugin/autopath/autopath_test.go @@ -0,0 +1,166 @@ +package autopath + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +var autopathTestCases = []test.Case{ + { + // search path expansion. + Qname: "b.example.org.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.CNAME("b.example.org. 3600 IN CNAME b.com."), + test.A("b.com." + defaultA), + }, + }, + { + // No search path expansion + Qname: "a.example.com.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.A("a.example.com." 
+ defaultA), + }, + }, +} + +func newTestAutoPath() *AutoPath { + ap := new(AutoPath) + ap.Zones = []string{"."} + ap.Next = nextHandler(map[string]int{ + "b.example.org.": dns.RcodeNameError, + "b.com.": dns.RcodeSuccess, + "a.example.com.": dns.RcodeSuccess, + }) + + ap.search = []string{"example.org.", "example.com.", "com.", ""} + return ap +} + +func TestAutoPath(t *testing.T) { + ap := newTestAutoPath() + ctx := context.TODO() + + for _, tc := range autopathTestCases { + m := tc.Msg() + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err := ap.ServeDNS(ctx, rec, m) + if err != nil { + t.Errorf("Expected no error, got %v", err) + continue + } + + // No sorting here as we want to check if the CNAME sits *before* the + // test of the answer. + resp := rec.Msg + + if err := test.Header(tc, resp); err != nil { + t.Error(err) + continue + } + if err := test.Section(tc, test.Answer, resp.Answer); err != nil { + t.Error(err) + } + if err := test.Section(tc, test.Ns, resp.Ns); err != nil { + t.Error(err) + } + if err := test.Section(tc, test.Extra, resp.Extra); err != nil { + t.Error(err) + } + } +} + +var autopathNoAnswerTestCases = []test.Case{ + { + // search path expansion, no answer + Qname: "c.example.org.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.CNAME("b.example.org. 3600 IN CNAME b.com."), + test.A("b.com." + defaultA), + }, + }, +} + +func TestAutoPathNoAnswer(t *testing.T) { + ap := newTestAutoPath() + ctx := context.TODO() + + for _, tc := range autopathNoAnswerTestCases { + m := tc.Msg() + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + rcode, err := ap.ServeDNS(ctx, rec, m) + if err != nil { + t.Errorf("Expected no error, got %v", err) + continue + } + if plugin.ClientWrite(rcode) { + t.Fatalf("Expected no client write, got one for rcode %d", rcode) + } + } +} + +// nextHandler returns a Handler that returns an answer for the question in the +// request per the domain->answer map. On success an RR will be returned: "qname 3600 IN A 127.0.0.53" +func nextHandler(mm map[string]int) test.Handler { + return test.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + rcode, ok := mm[r.Question[0].Name] + if !ok { + return dns.RcodeServerFailure, nil + } + + m := new(dns.Msg) + m.SetReply(r) + + switch rcode { + case dns.RcodeNameError: + m.Rcode = rcode + m.Ns = []dns.RR{soa} + w.WriteMsg(m) + return m.Rcode, nil + + case dns.RcodeSuccess: + m.Rcode = rcode + a, _ := dns.NewRR(r.Question[0].Name + defaultA) + m.Answer = []dns.RR{a} + + w.WriteMsg(m) + return m.Rcode, nil + default: + panic("nextHandler: unhandled rcode") + } + }) +} + +const defaultA = " 3600 IN A 127.0.0.53" + +var soa = func() dns.RR { + s, _ := dns.NewRR("example.org. 1800 IN SOA example.org. example.org. 
1502165581 14400 3600 604800 14400") + return s +}() + +func TestInSearchPath(t *testing.T) { + a := AutoPath{search: []string{"default.svc.cluster.local.", "svc.cluster.local.", "cluster.local."}} + + tests := []struct { + qname string + b bool + }{ + {"google.com", false}, + {"default.svc.cluster.local.", true}, + {"a.default.svc.cluster.local.", true}, + {"a.b.svc.cluster.local.", false}, + } + for i, tc := range tests { + got := firstInSearchPath(tc.qname, a.search) + if got != tc.b { + t.Errorf("Test %d, got %v, expected %v", i, got, tc.b) + } + } +} diff --git a/plugin/autopath/cname.go b/plugin/autopath/cname.go new file mode 100644 index 0000000..3b2c60f --- /dev/null +++ b/plugin/autopath/cname.go @@ -0,0 +1,25 @@ +package autopath + +import ( + "strings" + + "github.com/miekg/dns" +) + +// cnamer will prefix the answer section with a cname that points from original qname to the +// name of the first RR. It will also update the question section and put original in there. +func cnamer(m *dns.Msg, original string) { + for _, a := range m.Answer { + if strings.EqualFold(original, a.Header().Name) { + continue + } + m.Answer = append(m.Answer, nil) + copy(m.Answer[1:], m.Answer) + m.Answer[0] = &dns.CNAME{ + Hdr: dns.RR_Header{Name: original, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: a.Header().Ttl}, + Target: a.Header().Name, + } + break + } + m.Question[0].Name = original +} diff --git a/plugin/autopath/metrics.go b/plugin/autopath/metrics.go new file mode 100644 index 0000000..65a6cbd --- /dev/null +++ b/plugin/autopath/metrics.go @@ -0,0 +1,18 @@ +package autopath + +import ( + "github.com/coredns/coredns/plugin" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +var ( + // autoPathCount is counter of successfully autopath-ed queries. + autoPathCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "autopath", + Name: "success_total", + Help: "Counter of requests that did autopath.", + }, []string{"server"}) +) diff --git a/plugin/autopath/setup.go b/plugin/autopath/setup.go new file mode 100644 index 0000000..a041e36 --- /dev/null +++ b/plugin/autopath/setup.go @@ -0,0 +1,70 @@ +package autopath + +import ( + "fmt" + "strings" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + + "github.com/miekg/dns" +) + +func init() { plugin.Register("autopath", setup) } + +func setup(c *caddy.Controller) error { + ap, mw, err := autoPathParse(c) + if err != nil { + return plugin.Error("autopath", err) + } + + // Do this in OnStartup, so all plugin has been initialized. 
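+ // When the resolv-conf argument was given as @plugin, the named plugin is looked up here and, if it implements AutoPather, its AutoPath method becomes the dynamic search path source.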
+ c.OnStartup(func() error { + m := dnsserver.GetConfig(c).Handler(mw) + if m == nil { + return nil + } + if x, ok := m.(AutoPather); ok { + ap.searchFunc = x.AutoPath + } else { + return plugin.Error("autopath", fmt.Errorf("%s does not implement the AutoPather interface", mw)) + } + return nil + }) + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + ap.Next = next + return ap + }) + + return nil +} + +func autoPathParse(c *caddy.Controller) (*AutoPath, string, error) { + ap := &AutoPath{} + mw := "" + + for c.Next() { + zoneAndresolv := c.RemainingArgs() + if len(zoneAndresolv) < 1 { + return ap, "", fmt.Errorf("no resolv-conf specified") + } + resolv := zoneAndresolv[len(zoneAndresolv)-1] + if strings.HasPrefix(resolv, "@") { + mw = resolv[1:] + } else { + // assume file on disk + rc, err := dns.ClientConfigFromFile(resolv) + if err != nil { + return ap, "", fmt.Errorf("failed to parse %q: %v", resolv, err) + } + ap.search = rc.Search + plugin.Zones(ap.search).Normalize() + ap.search = append(ap.search, "") // sentinel value as demanded. + } + zones := zoneAndresolv[:len(zoneAndresolv)-1] + ap.Zones = plugin.OriginsFromArgsOrServerBlock(zones, c.ServerBlockKeys) + } + return ap, mw, nil +} diff --git a/plugin/autopath/setup_test.go b/plugin/autopath/setup_test.go new file mode 100644 index 0000000..4644c7d --- /dev/null +++ b/plugin/autopath/setup_test.go @@ -0,0 +1,77 @@ +package autopath + +import ( + "os" + "reflect" + "strings" + "testing" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/plugin/test" +) + +func TestSetupAutoPath(t *testing.T) { + resolv, rm, err := test.TempFile(os.TempDir(), resolvConf) + if err != nil { + t.Fatalf("Could not create resolv.conf test file %s: %s", resolvConf, err) + } + defer rm() + + tests := []struct { + input string + shouldErr bool + expectedZone string + expectedMw string // expected plugin. + expectedSearch []string // expected search path + expectedErrContent string // substring from the expected error. Empty for positive cases. + }{ + // positive + {`autopath @kubernetes`, false, "", "kubernetes", nil, ""}, + {`autopath example.org @kubernetes`, false, "example.org.", "kubernetes", nil, ""}, + {`autopath 10.0.0.0/8 @kubernetes`, false, "10.in-addr.arpa.", "kubernetes", nil, ""}, + {`autopath ` + resolv, false, "", "", []string{"bar.com.", "baz.com.", ""}, ""}, + // negative + {`autopath kubernetes`, true, "", "", nil, "open kubernetes: no such file or directory"}, + {`autopath`, true, "", "", nil, "no resolv-conf"}, + {`autopath ""`, true, "", "", nil, "no such file"}, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + ap, mw, err := autoPathParse(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected error but found %s for input %s", i, err, test.input) + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: Expected no error but found one for input %s. Error was: %v", i, test.input, err) + } + + if !strings.Contains(err.Error(), test.expectedErrContent) { + t.Errorf("Test %d: Expected error to contain: %v, found error: %v, input: %s", i, test.expectedErrContent, err, test.input) + } + } + + if !test.shouldErr && mw != test.expectedMw { + t.Errorf("Test %d, Plugin not correctly set for input %s. Expected: %s, actual: %s", i, test.input, test.expectedMw, mw) + } + if !test.shouldErr && ap.search != nil { + if !reflect.DeepEqual(test.expectedSearch, ap.search) { + t.Errorf("Test %d, wrong searchpath for input %s. 
Expected: '%v', actual: '%v'", i, test.input, test.expectedSearch, ap.search) + } + } + if !test.shouldErr && test.expectedZone != "" { + if test.expectedZone != ap.Zones[0] { + t.Errorf("Test %d, expected zone %q for input %s, got: %q", i, test.expectedZone, test.input, ap.Zones[0]) + } + } + } +} + +const resolvConf = `nameserver 1.2.3.4 +domain foo.com +search bar.com baz.com +options ndots:5 +` diff --git a/plugin/azure/README.md b/plugin/azure/README.md new file mode 100644 index 0000000..f5ed5ab --- /dev/null +++ b/plugin/azure/README.md @@ -0,0 +1,60 @@ +# azure + +## Name + +*azure* - enables serving zone data from Microsoft Azure DNS service. + +## Description + +The azure plugin is useful for serving zones from Microsoft Azure DNS. The *azure* plugin supports +all the DNS records supported by Azure, viz. A, AAAA, CNAME, MX, NS, PTR, SOA, SRV, and TXT +record types. NS record type is not supported by azure private DNS. + +## Syntax + +~~~ txt +azure RESOURCE_GROUP:ZONE... { + tenant TENANT_ID + client CLIENT_ID + secret CLIENT_SECRET + subscription SUBSCRIPTION_ID + environment ENVIRONMENT + fallthrough [ZONES...] + access private +} +~~~ + +* **RESOURCE_GROUP:ZONE** is the resource group to which the hosted zones belongs on Azure, + and **ZONE** the zone that contains data. + +* **CLIENT_ID** and **CLIENT_SECRET** are the credentials for Azure, and `tenant` specifies the + **TENANT_ID** to be used. **SUBSCRIPTION_ID** is the subscription ID. All of these are needed + to access the data in Azure. + +* `environment` specifies the Azure **ENVIRONMENT**. + +* `fallthrough` If zone matches and no record can be generated, pass request to the next plugin. + If **ZONES** is omitted, then fallthrough happens for all zones for which the plugin is + authoritative. + +* `access` specifies if the zone is `public` or `private`. Default is `public`. + +## Examples + +Enable the *azure* plugin with Azure credentials for private zones `example.org`, `example.private`: + +~~~ txt +example.org { + azure resource_group_foo:example.org resource_group_foo:example.private { + tenant 123abc-123abc-123abc-123abc + client 123abc-123abc-123abc-234xyz + subscription 123abc-123abc-123abc-563abc + secret mysecret + access private + } +} +~~~ + +## See Also + +The [Azure DNS Overview](https://docs.microsoft.com/en-us/azure/dns/dns-overview). diff --git a/plugin/azure/azure.go b/plugin/azure/azure.go new file mode 100644 index 0000000..e236a08 --- /dev/null +++ b/plugin/azure/azure.go @@ -0,0 +1,352 @@ +package azure + +import ( + "context" + "fmt" + "net" + "sync" + "time" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/file" + "github.com/coredns/coredns/plugin/pkg/fall" + "github.com/coredns/coredns/plugin/pkg/upstream" + "github.com/coredns/coredns/request" + + publicdns "github.com/Azure/azure-sdk-for-go/profiles/latest/dns/mgmt/dns" + privatedns "github.com/Azure/azure-sdk-for-go/profiles/latest/privatedns/mgmt/privatedns" + "github.com/miekg/dns" +) + +type zone struct { + id string + z *file.Zone + zone string + private bool +} + +type zones map[string][]*zone + +// Azure is the core struct of the azure plugin. +type Azure struct { + zoneNames []string + publicClient publicdns.RecordSetsClient + privateClient privatedns.RecordSetsClient + upstream *upstream.Upstream + zMu sync.RWMutex + zones zones + + Next plugin.Handler + Fall fall.F +} + +// New validates the input DNS zones and initializes the Azure struct. 
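+// The accessMap entry for a resourceGroup+zone pair decides whether the public or the private Azure DNS client is used to validate the zone here and to refresh it later in updateZones.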
+func New(ctx context.Context, publicClient publicdns.RecordSetsClient, privateClient privatedns.RecordSetsClient, keys map[string][]string, accessMap map[string]string) (*Azure, error) { + zones := make(map[string][]*zone, len(keys)) + names := make([]string, len(keys)) + var private bool + + for resourceGroup, znames := range keys { + for _, name := range znames { + switch accessMap[resourceGroup+name] { + case "public": + if _, err := publicClient.ListAllByDNSZone(context.Background(), resourceGroup, name, nil, ""); err != nil { + return nil, err + } + private = false + case "private": + if _, err := privateClient.ListComplete(context.Background(), resourceGroup, name, nil, ""); err != nil { + return nil, err + } + private = true + } + + fqdn := dns.Fqdn(name) + if _, ok := zones[fqdn]; !ok { + names = append(names, fqdn) + } + zones[fqdn] = append(zones[fqdn], &zone{id: resourceGroup, zone: name, private: private, z: file.NewZone(fqdn, "")}) + } + } + + return &Azure{ + publicClient: publicClient, + privateClient: privateClient, + zones: zones, + zoneNames: names, + upstream: upstream.New(), + }, nil +} + +// Run updates the zone from azure. +func (h *Azure) Run(ctx context.Context) error { + if err := h.updateZones(ctx); err != nil { + return err + } + go func() { + delay := 1 * time.Minute + timer := time.NewTimer(delay) + defer timer.Stop() + for { + timer.Reset(delay) + select { + case <-ctx.Done(): + log.Debugf("Breaking out of Azure update loop for %v: %v", h.zoneNames, ctx.Err()) + return + case <-timer.C: + if err := h.updateZones(ctx); err != nil && ctx.Err() == nil { + log.Errorf("Failed to update zones %v: %v", h.zoneNames, err) + } + } + } + }() + return nil +} + +func (h *Azure) updateZones(ctx context.Context) error { + var err error + var publicSet publicdns.RecordSetListResultPage + var privateSet privatedns.RecordSetListResultPage + errs := make([]string, 0) + for zName, z := range h.zones { + for i, hostedZone := range z { + newZ := file.NewZone(zName, "") + if hostedZone.private { + for privateSet, err = h.privateClient.List(ctx, hostedZone.id, hostedZone.zone, nil, ""); privateSet.NotDone(); err = privateSet.NextWithContext(ctx) { + updateZoneFromPrivateResourceSet(privateSet, newZ) + } + } else { + for publicSet, err = h.publicClient.ListByDNSZone(ctx, hostedZone.id, hostedZone.zone, nil, ""); publicSet.NotDone(); err = publicSet.NextWithContext(ctx) { + updateZoneFromPublicResourceSet(publicSet, newZ) + } + } + if err != nil { + errs = append(errs, fmt.Sprintf("failed to list resource records for %v from azure: %v", hostedZone.zone, err)) + } + newZ.Upstream = h.upstream + h.zMu.Lock() + (*z[i]).z = newZ + h.zMu.Unlock() + } + } + + if len(errs) != 0 { + return fmt.Errorf("errors updating zones: %v", errs) + } + return nil +} + +func updateZoneFromPublicResourceSet(recordSet publicdns.RecordSetListResultPage, newZ *file.Zone) { + for _, result := range *(recordSet.Response().Value) { + resultFqdn := *(result.RecordSetProperties.Fqdn) + resultTTL := uint32(*(result.RecordSetProperties.TTL)) + if result.RecordSetProperties.ARecords != nil { + for _, A := range *(result.RecordSetProperties.ARecords) { + a := &dns.A{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: resultTTL}, + A: net.ParseIP(*(A.Ipv4Address))} + newZ.Insert(a) + } + } + + if result.RecordSetProperties.AaaaRecords != nil { + for _, AAAA := range *(result.RecordSetProperties.AaaaRecords) { + aaaa := &dns.AAAA{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypeAAAA, 
Class: dns.ClassINET, Ttl: resultTTL}, + AAAA: net.ParseIP(*(AAAA.Ipv6Address))} + newZ.Insert(aaaa) + } + } + + if result.RecordSetProperties.MxRecords != nil { + for _, MX := range *(result.RecordSetProperties.MxRecords) { + mx := &dns.MX{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypeMX, Class: dns.ClassINET, Ttl: resultTTL}, + Preference: uint16(*(MX.Preference)), + Mx: dns.Fqdn(*(MX.Exchange))} + newZ.Insert(mx) + } + } + + if result.RecordSetProperties.PtrRecords != nil { + for _, PTR := range *(result.RecordSetProperties.PtrRecords) { + ptr := &dns.PTR{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypePTR, Class: dns.ClassINET, Ttl: resultTTL}, + Ptr: dns.Fqdn(*(PTR.Ptrdname))} + newZ.Insert(ptr) + } + } + + if result.RecordSetProperties.SrvRecords != nil { + for _, SRV := range *(result.RecordSetProperties.SrvRecords) { + srv := &dns.SRV{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypeSRV, Class: dns.ClassINET, Ttl: resultTTL}, + Priority: uint16(*(SRV.Priority)), + Weight: uint16(*(SRV.Weight)), + Port: uint16(*(SRV.Port)), + Target: dns.Fqdn(*(SRV.Target))} + newZ.Insert(srv) + } + } + + if result.RecordSetProperties.TxtRecords != nil { + for _, TXT := range *(result.RecordSetProperties.TxtRecords) { + txt := &dns.TXT{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: resultTTL}, + Txt: *(TXT.Value)} + newZ.Insert(txt) + } + } + + if result.RecordSetProperties.NsRecords != nil { + for _, NS := range *(result.RecordSetProperties.NsRecords) { + ns := &dns.NS{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: resultTTL}, + Ns: *(NS.Nsdname)} + newZ.Insert(ns) + } + } + + if result.RecordSetProperties.SoaRecord != nil { + SOA := result.RecordSetProperties.SoaRecord + soa := &dns.SOA{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: resultTTL}, + Minttl: uint32(*(SOA.MinimumTTL)), + Expire: uint32(*(SOA.ExpireTime)), + Retry: uint32(*(SOA.RetryTime)), + Refresh: uint32(*(SOA.RefreshTime)), + Serial: uint32(*(SOA.SerialNumber)), + Mbox: dns.Fqdn(*(SOA.Email)), + Ns: *(SOA.Host)} + newZ.Insert(soa) + } + + if result.RecordSetProperties.CnameRecord != nil { + CNAME := result.RecordSetProperties.CnameRecord.Cname + cname := &dns.CNAME{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: resultTTL}, + Target: dns.Fqdn(*CNAME)} + newZ.Insert(cname) + } + } +} + +func updateZoneFromPrivateResourceSet(recordSet privatedns.RecordSetListResultPage, newZ *file.Zone) { + for _, result := range *(recordSet.Response().Value) { + resultFqdn := *(result.RecordSetProperties.Fqdn) + resultTTL := uint32(*(result.RecordSetProperties.TTL)) + if result.RecordSetProperties.ARecords != nil { + for _, A := range *(result.RecordSetProperties.ARecords) { + a := &dns.A{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: resultTTL}, + A: net.ParseIP(*(A.Ipv4Address))} + newZ.Insert(a) + } + } + if result.RecordSetProperties.AaaaRecords != nil { + for _, AAAA := range *(result.RecordSetProperties.AaaaRecords) { + aaaa := &dns.AAAA{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: resultTTL}, + AAAA: net.ParseIP(*(AAAA.Ipv6Address))} + newZ.Insert(aaaa) + } + } + + if result.RecordSetProperties.MxRecords != nil { + for _, MX := range *(result.RecordSetProperties.MxRecords) { + mx := &dns.MX{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypeMX, Class: dns.ClassINET, Ttl: resultTTL}, + Preference: 
uint16(*(MX.Preference)), + Mx: dns.Fqdn(*(MX.Exchange))} + newZ.Insert(mx) + } + } + + if result.RecordSetProperties.PtrRecords != nil { + for _, PTR := range *(result.RecordSetProperties.PtrRecords) { + ptr := &dns.PTR{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypePTR, Class: dns.ClassINET, Ttl: resultTTL}, + Ptr: dns.Fqdn(*(PTR.Ptrdname))} + newZ.Insert(ptr) + } + } + + if result.RecordSetProperties.SrvRecords != nil { + for _, SRV := range *(result.RecordSetProperties.SrvRecords) { + srv := &dns.SRV{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypeSRV, Class: dns.ClassINET, Ttl: resultTTL}, + Priority: uint16(*(SRV.Priority)), + Weight: uint16(*(SRV.Weight)), + Port: uint16(*(SRV.Port)), + Target: dns.Fqdn(*(SRV.Target))} + newZ.Insert(srv) + } + } + + if result.RecordSetProperties.TxtRecords != nil { + for _, TXT := range *(result.RecordSetProperties.TxtRecords) { + txt := &dns.TXT{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: resultTTL}, + Txt: *(TXT.Value)} + newZ.Insert(txt) + } + } + + if result.RecordSetProperties.SoaRecord != nil { + SOA := result.RecordSetProperties.SoaRecord + soa := &dns.SOA{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: resultTTL}, + Minttl: uint32(*(SOA.MinimumTTL)), + Expire: uint32(*(SOA.ExpireTime)), + Retry: uint32(*(SOA.RetryTime)), + Refresh: uint32(*(SOA.RefreshTime)), + Serial: uint32(*(SOA.SerialNumber)), + Mbox: dns.Fqdn(*(SOA.Email)), + Ns: dns.Fqdn(*(SOA.Host))} + newZ.Insert(soa) + } + + if result.RecordSetProperties.CnameRecord != nil { + CNAME := result.RecordSetProperties.CnameRecord.Cname + cname := &dns.CNAME{Hdr: dns.RR_Header{Name: resultFqdn, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: resultTTL}, + Target: dns.Fqdn(*CNAME)} + newZ.Insert(cname) + } + } +} + +// ServeDNS implements the plugin.Handler interface. +func (h *Azure) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + state := request.Request{W: w, Req: r} + qname := state.Name() + + zone := plugin.Zones(h.zoneNames).Matches(qname) + if zone == "" { + return plugin.NextOrFailure(h.Name(), h.Next, ctx, w, r) + } + + zones, ok := h.zones[zone] // ok true if we are authoritative for the zone. + if !ok || zones == nil { + return dns.RcodeServerFailure, nil + } + + m := new(dns.Msg) + m.SetReply(r) + m.Authoritative = true + var result file.Result + for _, z := range zones { + h.zMu.RLock() + m.Answer, m.Ns, m.Extra, result = z.z.Lookup(ctx, state, qname) + h.zMu.RUnlock() + + // record type exists for this name (NODATA). + if len(m.Answer) != 0 || result == file.NoData { + break + } + } + + if len(m.Answer) == 0 && result != file.NoData && h.Fall.Through(qname) { + return plugin.NextOrFailure(h.Name(), h.Next, ctx, w, r) + } + + switch result { + case file.Success: + case file.NoData: + case file.NameError: + m.Rcode = dns.RcodeNameError + case file.Delegation: + m.Authoritative = false + case file.ServerFailure: + return dns.RcodeServerFailure, nil + } + + w.WriteMsg(m) + return dns.RcodeSuccess, nil +} + +// Name implements plugin.Handler.Name. 
+func (h *Azure) Name() string { return "azure" } diff --git a/plugin/azure/azure_test.go b/plugin/azure/azure_test.go new file mode 100644 index 0000000..0178300 --- /dev/null +++ b/plugin/azure/azure_test.go @@ -0,0 +1,180 @@ +package azure + +import ( + "context" + "reflect" + "testing" + + "github.com/coredns/coredns/plugin/file" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/pkg/fall" + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +var demoAzure = Azure{ + Next: testHandler(), + Fall: fall.Zero, + zoneNames: []string{"example.org.", "www.example.org.", "example.org.", "sample.example.org."}, + zones: testZones(), +} + +func testZones() zones { + zones := make(map[string][]*zone) + zones["example.org."] = append(zones["example.org."], &zone{zone: "example.org."}) + newZ := file.NewZone("example.org.", "") + + for _, rr := range []string{ + "example.org. 300 IN A 1.2.3.4", + "example.org. 300 IN AAAA 2001:db8:85a3::8a2e:370:7334", + "www.example.org. 300 IN A 1.2.3.4", + "www.example.org. 300 IN A 1.2.3.4", + "org. 172800 IN NS ns3-06.azure-dns.org.", + "org. 300 IN SOA ns1-06.azure-dns.com. azuredns-hostmaster.microsoft.com. 1 3600 300 2419200 300", + "cname.example.org. 300 IN CNAME example.org", + "mail.example.org. 300 IN MX 10 mailserver.example.com", + "ptr.example.org. 300 IN PTR www.ptr-example.com", + "example.org. 300 IN SRV 1 10 5269 srv-1.example.com.", + "example.org. 300 IN SRV 1 10 5269 srv-2.example.com.", + "txt.example.org. 300 IN TXT \"TXT for example.org\"", + } { + r, _ := dns.NewRR(rr) + newZ.Insert(r) + } + zones["example.org."][0].z = newZ + return zones +} + +func testHandler() test.HandlerFunc { + return func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + state := request.Request{W: w, Req: r} + qname := state.Name() + m := new(dns.Msg) + rcode := dns.RcodeServerFailure + if qname == "example.gov." { // No records match, test fallthrough. + m.SetReply(r) + rr := test.A("example.gov. 300 IN A 2.4.6.8") + m.Answer = []dns.RR{rr} + m.Authoritative = true + rcode = dns.RcodeSuccess + } + m.SetRcode(r, rcode) + w.WriteMsg(m) + return rcode, nil + } +} + +func TestAzure(t *testing.T) { + tests := []struct { + qname string + qtype uint16 + wantRetCode int + wantAnswer []string + wantMsgRCode int + wantNS []string + expectedErr error + }{ + { + qname: "example.org.", + qtype: dns.TypeA, + wantAnswer: []string{"example.org. 300 IN A 1.2.3.4"}, + }, + { + qname: "example.org", + qtype: dns.TypeAAAA, + wantAnswer: []string{"example.org. 300 IN AAAA 2001:db8:85a3::8a2e:370:7334"}, + }, + { + qname: "example.org", + qtype: dns.TypeSOA, + wantAnswer: []string{"org. 300 IN SOA ns1-06.azure-dns.com. azuredns-hostmaster.microsoft.com. 1 3600 300 2419200 300"}, + }, + { + qname: "badexample.com", + qtype: dns.TypeA, + wantRetCode: dns.RcodeServerFailure, + wantMsgRCode: dns.RcodeServerFailure, + }, + { + qname: "example.gov", + qtype: dns.TypeA, + wantAnswer: []string{"example.gov. 300 IN A 2.4.6.8"}, + }, + { + qname: "example.org", + qtype: dns.TypeSRV, + wantAnswer: []string{"example.org. 300 IN SRV 1 10 5269 srv-1.example.com.", "example.org. 300 IN SRV 1 10 5269 srv-2.example.com."}, + }, + { + qname: "cname.example.org.", + qtype: dns.TypeCNAME, + wantAnswer: []string{"cname.example.org. 300 IN CNAME example.org."}, + }, + { + qname: "cname.example.org.", + qtype: dns.TypeA, + wantAnswer: []string{"cname.example.org. 
300 IN CNAME example.org.", "example.org. 300 IN A 1.2.3.4"}, + }, + { + qname: "mail.example.org.", + qtype: dns.TypeMX, + wantAnswer: []string{"mail.example.org. 300 IN MX 10 mailserver.example.com."}, + }, + { + qname: "ptr.example.org.", + qtype: dns.TypePTR, + wantAnswer: []string{"ptr.example.org. 300 IN PTR www.ptr-example.com."}, + }, + { + qname: "txt.example.org.", + qtype: dns.TypeTXT, + wantAnswer: []string{"txt.example.org. 300 IN TXT \"TXT for example.org\""}, + }, + } + + for ti, tc := range tests { + req := new(dns.Msg) + req.SetQuestion(dns.Fqdn(tc.qname), tc.qtype) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + code, err := demoAzure.ServeDNS(context.Background(), rec, req) + + if err != tc.expectedErr { + t.Fatalf("Test %d: Expected error %v, but got %v", ti, tc.expectedErr, err) + } + + if code != tc.wantRetCode { + t.Fatalf("Test %d: Expected returned status code %s, but got %s", ti, dns.RcodeToString[tc.wantRetCode], dns.RcodeToString[code]) + } + + if tc.wantMsgRCode != rec.Msg.Rcode { + t.Errorf("Test %d: Unexpected msg status code. Want: %s, got: %s", ti, dns.RcodeToString[tc.wantMsgRCode], dns.RcodeToString[rec.Msg.Rcode]) + } + + if len(tc.wantAnswer) != len(rec.Msg.Answer) { + t.Errorf("Test %d: Unexpected number of Answers. Want: %d, got: %d", ti, len(tc.wantAnswer), len(rec.Msg.Answer)) + } else { + for i, gotAnswer := range rec.Msg.Answer { + if gotAnswer.String() != tc.wantAnswer[i] { + t.Errorf("Test %d: Unexpected answer.\nWant:\n\t%s\nGot:\n\t%s", ti, tc.wantAnswer[i], gotAnswer) + } + } + } + + if len(tc.wantNS) != len(rec.Msg.Ns) { + t.Errorf("Test %d: Unexpected NS number. Want: %d, got: %d", ti, len(tc.wantNS), len(rec.Msg.Ns)) + } else { + for i, ns := range rec.Msg.Ns { + got, ok := ns.(*dns.SOA) + if !ok { + t.Errorf("Test %d: Unexpected NS type. 
Want: SOA, got: %v", ti, reflect.TypeOf(got)) + } + if got.String() != tc.wantNS[i] { + t.Errorf("Test %d: Unexpected NS.\nWant: %v\nGot: %v", ti, tc.wantNS[i], got) + } + } + } + } +} diff --git a/plugin/azure/setup.go b/plugin/azure/setup.go new file mode 100644 index 0000000..6cabe05 --- /dev/null +++ b/plugin/azure/setup.go @@ -0,0 +1,144 @@ +package azure + +import ( + "context" + "strings" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/fall" + clog "github.com/coredns/coredns/plugin/pkg/log" + + publicAzureDNS "github.com/Azure/azure-sdk-for-go/profiles/latest/dns/mgmt/dns" + privateAzureDNS "github.com/Azure/azure-sdk-for-go/profiles/latest/privatedns/mgmt/privatedns" + azurerest "github.com/Azure/go-autorest/autorest/azure" + "github.com/Azure/go-autorest/autorest/azure/auth" +) + +var log = clog.NewWithPlugin("azure") + +func init() { plugin.Register("azure", setup) } + +func setup(c *caddy.Controller) error { + env, keys, accessMap, fall, err := parse(c) + if err != nil { + return plugin.Error("azure", err) + } + ctx, cancel := context.WithCancel(context.Background()) + + publicDNSClient := publicAzureDNS.NewRecordSetsClient(env.Values[auth.SubscriptionID]) + if publicDNSClient.Authorizer, err = env.GetAuthorizer(); err != nil { + cancel() + return plugin.Error("azure", err) + } + + privateDNSClient := privateAzureDNS.NewRecordSetsClient(env.Values[auth.SubscriptionID]) + if privateDNSClient.Authorizer, err = env.GetAuthorizer(); err != nil { + cancel() + return plugin.Error("azure", err) + } + + h, err := New(ctx, publicDNSClient, privateDNSClient, keys, accessMap) + if err != nil { + cancel() + return plugin.Error("azure", err) + } + h.Fall = fall + if err := h.Run(ctx); err != nil { + cancel() + return plugin.Error("azure", err) + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + h.Next = next + return h + }) + c.OnShutdown(func() error { cancel(); return nil }) + return nil +} + +func parse(c *caddy.Controller) (auth.EnvironmentSettings, map[string][]string, map[string]string, fall.F, error) { + resourceGroupMapping := map[string][]string{} + accessMap := map[string]string{} + resourceGroupSet := map[string]struct{}{} + azureEnv := azurerest.PublicCloud + env := auth.EnvironmentSettings{Values: map[string]string{}} + + var fall fall.F + var access string + var resourceGroup string + var zoneName string + + for c.Next() { + args := c.RemainingArgs() + + for i := 0; i < len(args); i++ { + parts := strings.SplitN(args[i], ":", 2) + if len(parts) != 2 { + return env, resourceGroupMapping, accessMap, fall, c.Errf("invalid resource group/zone: %q", args[i]) + } + resourceGroup, zoneName = parts[0], parts[1] + if resourceGroup == "" || zoneName == "" { + return env, resourceGroupMapping, accessMap, fall, c.Errf("invalid resource group/zone: %q", args[i]) + } + if _, ok := resourceGroupSet[resourceGroup+zoneName]; ok { + return env, resourceGroupMapping, accessMap, fall, c.Errf("conflicting zone: %q", args[i]) + } + + resourceGroupSet[resourceGroup+zoneName] = struct{}{} + accessMap[resourceGroup+zoneName] = "public" + resourceGroupMapping[resourceGroup] = append(resourceGroupMapping[resourceGroup], zoneName) + } + + for c.NextBlock() { + switch c.Val() { + case "subscription": + if !c.NextArg() { + return env, resourceGroupMapping, accessMap, fall, c.ArgErr() + } + env.Values[auth.SubscriptionID] = c.Val() + case "tenant": + if !c.NextArg() { 
+ return env, resourceGroupMapping, accessMap, fall, c.ArgErr() + } + env.Values[auth.TenantID] = c.Val() + case "client": + if !c.NextArg() { + return env, resourceGroupMapping, accessMap, fall, c.ArgErr() + } + env.Values[auth.ClientID] = c.Val() + case "secret": + if !c.NextArg() { + return env, resourceGroupMapping, accessMap, fall, c.ArgErr() + } + env.Values[auth.ClientSecret] = c.Val() + case "environment": + if !c.NextArg() { + return env, resourceGroupMapping, accessMap, fall, c.ArgErr() + } + var err error + if azureEnv, err = azurerest.EnvironmentFromName(c.Val()); err != nil { + return env, resourceGroupMapping, accessMap, fall, c.Errf("cannot set azure environment: %q", err.Error()) + } + case "fallthrough": + fall.SetZonesFromArgs(c.RemainingArgs()) + case "access": + if !c.NextArg() { + return env, resourceGroupMapping, accessMap, fall, c.ArgErr() + } + access = c.Val() + if access != "public" && access != "private" { + return env, resourceGroupMapping, accessMap, fall, c.Errf("invalid access value: can be public/private, found: %s", access) + } + accessMap[resourceGroup+zoneName] = access + default: + return env, resourceGroupMapping, accessMap, fall, c.Errf("unknown property: %q", c.Val()) + } + } + } + + env.Values[auth.Resource] = azureEnv.ResourceManagerEndpoint + env.Environment = azureEnv + return env, resourceGroupMapping, accessMap, fall, nil +} diff --git a/plugin/azure/setup_test.go b/plugin/azure/setup_test.go new file mode 100644 index 0000000..c6c26b1 --- /dev/null +++ b/plugin/azure/setup_test.go @@ -0,0 +1,71 @@ +package azure + +import ( + "testing" + + "github.com/coredns/caddy" +) + +func TestSetup(t *testing.T) { + tests := []struct { + body string + expectedError bool + }{ + {`azure`, false}, + {`azure :`, true}, + {`azure resource_set:zone`, false}, + {`azure resource_set:zone { + tenant +}`, true}, + {`azure resource_set:zone { + tenant abc +}`, false}, + {`azure resource_set:zone { + client +}`, true}, + {`azure resource_set:zone { + client abc +}`, false}, + {`azure resource_set:zone { + subscription +}`, true}, + {`azure resource_set:zone { + subscription abc +}`, false}, + {`azure resource_set:zone { + foo +}`, true}, + {`azure resource_set:zone { + tenant tenant_id + client client_id + secret client_secret + subscription subscription_id + access public +}`, false}, + {`azure resource_set:zone { + fallthrough +}`, false}, + {`azure resource_set:zone { + environment AZUREPUBLICCLOUD + }`, false}, + {`azure resource_set:zone resource_set:zone { + fallthrough + }`, true}, + {`azure resource_set:zone,zone2 { + access private + }`, false}, + {`azure resource-set:zone { + access public + }`, false}, + {`azure resource-set:zone { + access foo + }`, true}, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.body) + if _, _, _, _, err := parse(c); (err == nil) == test.expectedError { + t.Fatalf("Unexpected errors: %v in test: %d\n\t%s", err, i, test.body) + } + } +} diff --git a/plugin/backend.go b/plugin/backend.go new file mode 100644 index 0000000..a0217c9 --- /dev/null +++ b/plugin/backend.go @@ -0,0 +1,40 @@ +package plugin + +import ( + "context" + + "github.com/coredns/coredns/plugin/etcd/msg" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// ServiceBackend defines a (dynamic) backend that returns a slice of service definitions. +type ServiceBackend interface { + // Services communicates with the backend to retrieve the service definitions. Exact indicates + // on exact match should be returned. 
+ Services(ctx context.Context, state request.Request, exact bool, opt Options) ([]msg.Service, error) + + // Reverse communicates with the backend to retrieve service definition based on a IP address + // instead of a name. I.e. a reverse DNS lookup. + Reverse(ctx context.Context, state request.Request, exact bool, opt Options) ([]msg.Service, error) + + // Lookup is used to find records else where. + Lookup(ctx context.Context, state request.Request, name string, typ uint16) (*dns.Msg, error) + + // Returns _all_ services that matches a certain name. + // Note: it does not implement a specific service. + Records(ctx context.Context, state request.Request, exact bool) ([]msg.Service, error) + + // IsNameError returns true if err indicated a record not found condition + IsNameError(err error) bool + + // Serial returns a SOA serial number to construct a SOA record. + Serial(state request.Request) uint32 + + // MinTTL returns the minimum TTL to be used in the SOA record. + MinTTL(state request.Request) uint32 +} + +// Options are extra options that can be specified for a lookup. +type Options struct{} diff --git a/plugin/backend_lookup.go b/plugin/backend_lookup.go new file mode 100644 index 0000000..0887bb4 --- /dev/null +++ b/plugin/backend_lookup.go @@ -0,0 +1,560 @@ +package plugin + +import ( + "context" + "fmt" + "math" + "net" + + "github.com/coredns/coredns/plugin/etcd/msg" + "github.com/coredns/coredns/plugin/pkg/dnsutil" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// A returns A records from Backend or an error. +func A(ctx context.Context, b ServiceBackend, zone string, state request.Request, previousRecords []dns.RR, opt Options) (records []dns.RR, truncated bool, err error) { + services, err := checkForApex(ctx, b, zone, state, opt) + if err != nil { + return nil, false, err + } + + dup := make(map[string]struct{}) + + for _, serv := range services { + what, ip := serv.HostType() + + switch what { + case dns.TypeCNAME: + if Name(state.Name()).Matches(dns.Fqdn(serv.Host)) { + // x CNAME x is a direct loop, don't add those + // in etcd/skydns w.x CNAME x is also direct loop due to the "recursive" nature of search results + continue + } + + newRecord := serv.NewCNAME(state.QName(), serv.Host) + if len(previousRecords) > 7 { + // don't add it, and just continue + continue + } + if dnsutil.DuplicateCNAME(newRecord, previousRecords) { + continue + } + if dns.IsSubDomain(zone, dns.Fqdn(serv.Host)) { + state1 := state.NewWithQuestion(serv.Host, state.QType()) + state1.Zone = zone + nextRecords, tc, err := A(ctx, b, zone, state1, append(previousRecords, newRecord), opt) + + if err == nil { + // Not only have we found something we should add the CNAME and the IP addresses. + if len(nextRecords) > 0 { + records = append(records, newRecord) + records = append(records, nextRecords...) + } + } + if tc { + truncated = true + } + continue + } + // This means we can not complete the CNAME, try to look else where. + target := newRecord.Target + // Lookup + m1, e1 := b.Lookup(ctx, state, target, state.QType()) + if e1 != nil { + continue + } + if m1.Truncated { + truncated = true + } + // Len(m1.Answer) > 0 here is well? + records = append(records, newRecord) + records = append(records, m1.Answer...) 
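+ // Both the CNAME and the answers from the external lookup are now in records; move on to the next service.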
+ continue + + case dns.TypeA: + if _, ok := dup[serv.Host]; !ok { + dup[serv.Host] = struct{}{} + records = append(records, serv.NewA(state.QName(), ip)) + } + + case dns.TypeAAAA: + // nada + } + } + return records, truncated, nil +} + +// AAAA returns AAAA records from Backend or an error. +func AAAA(ctx context.Context, b ServiceBackend, zone string, state request.Request, previousRecords []dns.RR, opt Options) (records []dns.RR, truncated bool, err error) { + services, err := checkForApex(ctx, b, zone, state, opt) + if err != nil { + return nil, false, err + } + + dup := make(map[string]struct{}) + + for _, serv := range services { + what, ip := serv.HostType() + + switch what { + case dns.TypeCNAME: + // Try to resolve as CNAME if it's not an IP, but only if we don't create loops. + if Name(state.Name()).Matches(dns.Fqdn(serv.Host)) { + // x CNAME x is a direct loop, don't add those + // in etcd/skydns w.x CNAME x is also direct loop due to the "recursive" nature of search results + continue + } + + newRecord := serv.NewCNAME(state.QName(), serv.Host) + if len(previousRecords) > 7 { + // don't add it, and just continue + continue + } + if dnsutil.DuplicateCNAME(newRecord, previousRecords) { + continue + } + if dns.IsSubDomain(zone, dns.Fqdn(serv.Host)) { + state1 := state.NewWithQuestion(serv.Host, state.QType()) + state1.Zone = zone + nextRecords, tc, err := AAAA(ctx, b, zone, state1, append(previousRecords, newRecord), opt) + + if err == nil { + // Not only have we found something we should add the CNAME and the IP addresses. + if len(nextRecords) > 0 { + records = append(records, newRecord) + records = append(records, nextRecords...) + } + } + if tc { + truncated = true + } + continue + } + // This means we can not complete the CNAME, try to look else where. + target := newRecord.Target + m1, e1 := b.Lookup(ctx, state, target, state.QType()) + if e1 != nil { + continue + } + if m1.Truncated { + truncated = true + } + // Len(m1.Answer) > 0 here is well? + records = append(records, newRecord) + records = append(records, m1.Answer...) + continue + // both here again + + case dns.TypeA: + // nada + + case dns.TypeAAAA: + if _, ok := dup[serv.Host]; !ok { + dup[serv.Host] = struct{}{} + records = append(records, serv.NewAAAA(state.QName(), ip)) + } + } + } + return records, truncated, nil +} + +// SRV returns SRV records from the Backend. +// If the Target is not a name but an IP address, a name is created on the fly. +func SRV(ctx context.Context, b ServiceBackend, zone string, state request.Request, opt Options) (records, extra []dns.RR, err error) { + services, err := b.Services(ctx, state, false, opt) + if err != nil { + return nil, nil, err + } + + dup := make(map[item]struct{}) + lookup := make(map[string]struct{}) + + // Looping twice to get the right weight vs priority. This might break because we may drop duplicate SRV records latter on. + w := make(map[int]int) + for _, serv := range services { + weight := 100 + if serv.Weight != 0 { + weight = serv.Weight + } + if _, ok := w[serv.Priority]; !ok { + w[serv.Priority] = weight + continue + } + w[serv.Priority] += weight + } + for _, serv := range services { + // Don't add the entry if the port is -1 (invalid). The kubernetes plugin uses port -1 when a service/endpoint + // does not have any declared ports. 
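+ // For the records that remain, the weight computed below is floor(100*weight/total weight for the record's priority), with an unset weight counted as 100 and a minimum of 1, so weights within one priority sum to roughly 100.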
+ if serv.Port == -1 { + continue + } + w1 := 100.0 / float64(w[serv.Priority]) + if serv.Weight == 0 { + w1 *= 100 + } else { + w1 *= float64(serv.Weight) + } + weight := uint16(math.Floor(w1)) + // weight should be at least 1 + if weight == 0 { + weight = 1 + } + + what, ip := serv.HostType() + + switch what { + case dns.TypeCNAME: + srv := serv.NewSRV(state.QName(), weight) + records = append(records, srv) + + if _, ok := lookup[srv.Target]; ok { + break + } + + lookup[srv.Target] = struct{}{} + + if !dns.IsSubDomain(zone, srv.Target) { + m1, e1 := b.Lookup(ctx, state, srv.Target, dns.TypeA) + if e1 == nil { + extra = append(extra, m1.Answer...) + } + + m1, e1 = b.Lookup(ctx, state, srv.Target, dns.TypeAAAA) + if e1 == nil { + // If we have seen CNAME's we *assume* that they are already added. + for _, a := range m1.Answer { + if _, ok := a.(*dns.CNAME); !ok { + extra = append(extra, a) + } + } + } + break + } + // Internal name, we should have some info on them, either v4 or v6 + // Clients expect a complete answer, because we are a recursor in their view. + state1 := state.NewWithQuestion(srv.Target, dns.TypeA) + addr, _, e1 := A(ctx, b, zone, state1, nil, opt) + if e1 == nil { + extra = append(extra, addr...) + } + // TODO(miek): AAAA as well here. + + case dns.TypeA, dns.TypeAAAA: + addr := serv.Host + serv.Host = msg.Domain(serv.Key) + srv := serv.NewSRV(state.QName(), weight) + + if ok := isDuplicate(dup, srv.Target, "", srv.Port); !ok { + records = append(records, srv) + } + + if ok := isDuplicate(dup, srv.Target, addr, 0); !ok { + extra = append(extra, newAddress(serv, srv.Target, ip, what)) + } + } + } + return records, extra, nil +} + +// MX returns MX records from the Backend. If the Target is not a name but an IP address, a name is created on the fly. +func MX(ctx context.Context, b ServiceBackend, zone string, state request.Request, opt Options) (records, extra []dns.RR, err error) { + services, err := b.Services(ctx, state, false, opt) + if err != nil { + return nil, nil, err + } + + dup := make(map[item]struct{}) + lookup := make(map[string]struct{}) + for _, serv := range services { + if !serv.Mail { + continue + } + what, ip := serv.HostType() + switch what { + case dns.TypeCNAME: + mx := serv.NewMX(state.QName()) + records = append(records, mx) + if _, ok := lookup[mx.Mx]; ok { + break + } + + lookup[mx.Mx] = struct{}{} + + if !dns.IsSubDomain(zone, mx.Mx) { + m1, e1 := b.Lookup(ctx, state, mx.Mx, dns.TypeA) + if e1 == nil { + extra = append(extra, m1.Answer...) + } + + m1, e1 = b.Lookup(ctx, state, mx.Mx, dns.TypeAAAA) + if e1 == nil { + // If we have seen CNAME's we *assume* that they are already added. + for _, a := range m1.Answer { + if _, ok := a.(*dns.CNAME); !ok { + extra = append(extra, a) + } + } + } + break + } + // Internal name + state1 := state.NewWithQuestion(mx.Mx, dns.TypeA) + addr, _, e1 := A(ctx, b, zone, state1, nil, opt) + if e1 == nil { + extra = append(extra, addr...) + } + // TODO(miek): AAAA as well here. + + case dns.TypeA, dns.TypeAAAA: + addr := serv.Host + serv.Host = msg.Domain(serv.Key) + mx := serv.NewMX(state.QName()) + + if ok := isDuplicate(dup, mx.Mx, "", mx.Preference); !ok { + records = append(records, mx) + } + // Fake port to be 0 for address... + if ok := isDuplicate(dup, serv.Host, addr, 0); !ok { + extra = append(extra, newAddress(serv, serv.Host, ip, what)) + } + } + } + return records, extra, nil +} + +// CNAME returns CNAME records from the backend or an error. 
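+// Only the first service returned by the backend is considered, and only when its Host is a name rather than an IP address (an IP belongs in an A/AAAA record, not a CNAME).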
+func CNAME(ctx context.Context, b ServiceBackend, zone string, state request.Request, opt Options) (records []dns.RR, err error) { + services, err := b.Services(ctx, state, true, opt) + if err != nil { + return nil, err + } + + if len(services) > 0 { + serv := services[0] + if ip := net.ParseIP(serv.Host); ip == nil { + records = append(records, serv.NewCNAME(state.QName(), serv.Host)) + } + } + return records, nil +} + +// TXT returns TXT records from Backend or an error. +func TXT(ctx context.Context, b ServiceBackend, zone string, state request.Request, previousRecords []dns.RR, opt Options) (records []dns.RR, truncated bool, err error) { + services, err := b.Services(ctx, state, false, opt) + if err != nil { + return nil, false, err + } + + dup := make(map[string]struct{}) + + for _, serv := range services { + what, _ := serv.HostType() + + switch what { + case dns.TypeCNAME: + if Name(state.Name()).Matches(dns.Fqdn(serv.Host)) { + // x CNAME x is a direct loop, don't add those + // in etcd/skydns w.x CNAME x is also direct loop due to the "recursive" nature of search results + continue + } + + newRecord := serv.NewCNAME(state.QName(), serv.Host) + if len(previousRecords) > 7 { + // don't add it, and just continue + continue + } + if dnsutil.DuplicateCNAME(newRecord, previousRecords) { + continue + } + if dns.IsSubDomain(zone, dns.Fqdn(serv.Host)) { + state1 := state.NewWithQuestion(serv.Host, state.QType()) + state1.Zone = zone + nextRecords, tc, err := TXT(ctx, b, zone, state1, append(previousRecords, newRecord), opt) + if tc { + truncated = true + } + if err == nil { + // Not only have we found something we should add the CNAME and the IP addresses. + if len(nextRecords) > 0 { + records = append(records, newRecord) + records = append(records, nextRecords...) + } + } + continue + } + // This means we can not complete the CNAME, try to look else where. + target := newRecord.Target + // Lookup + m1, e1 := b.Lookup(ctx, state, target, state.QType()) + if e1 != nil { + continue + } + // Len(m1.Answer) > 0 here is well? + records = append(records, newRecord) + records = append(records, m1.Answer...) + continue + + case dns.TypeTXT: + if _, ok := dup[serv.Text]; !ok { + dup[serv.Text] = struct{}{} + records = append(records, serv.NewTXT(state.QName())) + } + } + } + + return records, truncated, nil +} + +// PTR returns the PTR records from the backend, only services that have a domain name as host are included. +func PTR(ctx context.Context, b ServiceBackend, zone string, state request.Request, opt Options) (records []dns.RR, err error) { + services, err := b.Reverse(ctx, state, true, opt) + if err != nil { + return nil, err + } + + dup := make(map[string]struct{}) + + for _, serv := range services { + if ip := net.ParseIP(serv.Host); ip == nil { + if _, ok := dup[serv.Host]; !ok { + dup[serv.Host] = struct{}{} + records = append(records, serv.NewPTR(state.QName(), serv.Host)) + } + } + } + return records, nil +} + +// NS returns NS records from the backend +func NS(ctx context.Context, b ServiceBackend, zone string, state request.Request, opt Options) (records, extra []dns.RR, err error) { + // NS record for this zone live in a special place, ns.dns.<zone>. Fake our lookup. + // only a tad bit fishy... + old := state.QName() + + state.Clear() + state.Req.Question[0].Name = dnsutil.Join("ns.dns.", zone) + services, err := b.Services(ctx, state, false, opt) + if err != nil { + return nil, nil, err + } + // ... 
and reset + state.Req.Question[0].Name = old + + seen := map[string]bool{} + + for _, serv := range services { + what, ip := serv.HostType() + switch what { + case dns.TypeCNAME: + return nil, nil, fmt.Errorf("NS record must be an IP address: %s", serv.Host) + + case dns.TypeA, dns.TypeAAAA: + serv.Host = msg.Domain(serv.Key) + ns := serv.NewNS(state.QName()) + extra = append(extra, newAddress(serv, ns.Ns, ip, what)) + if _, ok := seen[ns.Ns]; ok { + continue + } + seen[ns.Ns] = true + records = append(records, ns) + } + } + return records, extra, nil +} + +// SOA returns a SOA record from the backend. +func SOA(ctx context.Context, b ServiceBackend, zone string, state request.Request, opt Options) ([]dns.RR, error) { + minTTL := b.MinTTL(state) + ttl := uint32(300) + if minTTL < ttl { + ttl = minTTL + } + + header := dns.RR_Header{Name: zone, Rrtype: dns.TypeSOA, Ttl: ttl, Class: dns.ClassINET} + + Mbox := dnsutil.Join(hostmaster, zone) + Ns := dnsutil.Join("ns.dns", zone) + + soa := &dns.SOA{Hdr: header, + Mbox: Mbox, + Ns: Ns, + Serial: b.Serial(state), + Refresh: 7200, + Retry: 1800, + Expire: 86400, + Minttl: minTTL, + } + return []dns.RR{soa}, nil +} + +// BackendError writes an error response to the client. +func BackendError(ctx context.Context, b ServiceBackend, zone string, rcode int, state request.Request, err error, opt Options) (int, error) { + m := new(dns.Msg) + m.SetRcode(state.Req, rcode) + m.Authoritative = true + m.Ns, _ = SOA(ctx, b, zone, state, opt) + + state.W.WriteMsg(m) + // Return success as the rcode to signal we have written to the client. + return dns.RcodeSuccess, err +} + +func newAddress(s msg.Service, name string, ip net.IP, what uint16) dns.RR { + hdr := dns.RR_Header{Name: name, Rrtype: what, Class: dns.ClassINET, Ttl: s.TTL} + + if what == dns.TypeA { + return &dns.A{Hdr: hdr, A: ip} + } + // Should always be dns.TypeAAAA + return &dns.AAAA{Hdr: hdr, AAAA: ip} +} + +// checkForApex checks the special apex.dns directory for records that will be returned as A or AAAA. +func checkForApex(ctx context.Context, b ServiceBackend, zone string, state request.Request, opt Options) ([]msg.Service, error) { + if state.Name() != zone { + return b.Services(ctx, state, false, opt) + } + + // If the zone name itself is queried we fake the query to search for a special entry + // this is equivalent to the NS search code. + old := state.QName() + state.Clear() + state.Req.Question[0].Name = dnsutil.Join("apex.dns", zone) + + services, err := b.Services(ctx, state, false, opt) + if err == nil { + state.Req.Question[0].Name = old + return services, err + } + + state.Req.Question[0].Name = old + return b.Services(ctx, state, false, opt) +} + +// item holds records. +type item struct { + name string // name of the record (either owner or something else unique). + port uint16 // port of the record (used for address records, A and AAAA). + addr string // address of the record (A and AAAA). +} + +// isDuplicate uses m to see if the combo (name, addr, port) already exists. If it does +// not exist already IsDuplicate will also add the record to the map. 
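+// A minimal illustration with hypothetical values:
+//
+//    dup := make(map[item]struct{})
+//    isDuplicate(dup, "srv.example.org.", "", 8080)      // false, first time this (name, port) is seen
+//    isDuplicate(dup, "srv.example.org.", "", 8080)      // true, already recorded
+//    isDuplicate(dup, "srv.example.org.", "10.0.0.1", 0) // false, keyed on (name, addr) instead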
+func isDuplicate(m map[item]struct{}, name, addr string, port uint16) bool { + if addr != "" { + _, ok := m[item{name, 0, addr}] + if !ok { + m[item{name, 0, addr}] = struct{}{} + } + return ok + } + _, ok := m[item{name, port, ""}] + if !ok { + m[item{name, port, ""}] = struct{}{} + } + return ok +} + +const hostmaster = "hostmaster" diff --git a/plugin/bind/README.md b/plugin/bind/README.md new file mode 100644 index 0000000..1c0f0c5 --- /dev/null +++ b/plugin/bind/README.md @@ -0,0 +1,113 @@ +# bind + +## Name + +*bind* - overrides the host to which the server should bind. + +## Description + +Normally, the listener binds to the wildcard host. However, you may want the listener to bind to +another IP instead. + +If several addresses are provided, a listener will be open on each of the IP provided. + +Each address has to be an IP or name of one of the interfaces of the host. Bind by interface name, binds to the IPs on that interface at the time of startup or reload (reload will happen with a SIGHUP or if the config file changes). + +If the given argument is an interface name, and that interface has several IP addresses, CoreDNS will listen on all of the interface IP addresses (including IPv4 and IPv6), except for IPv6 link-local addresses on that interface. + +## Syntax + +In its basic form, a simple bind uses this syntax: + +~~~ txt +bind ADDRESS|IFACE ... +~~~ + +You can also exclude some addresses with their IP address or interface name in expanded syntax: + +~~~ +bind ADDRESS|IFACE ... { + except ADDRESS|IFACE ... +} +~~~ + + + +* **ADDRESS|IFACE** is an IP address or interface name to bind to. +When several addresses are provided a listener will be opened on each of the addresses. Please read the *Description* for more details. +* `except`, excludes interfaces or IP addresses to bind to. `except` option only excludes addresses for the current `bind` directive if multiple `bind` directives are used in the same server block. +## Examples + +To make your socket accessible only to that machine, bind to IP 127.0.0.1 (localhost): + +~~~ corefile +. { + bind 127.0.0.1 +} +~~~ + +To allow processing DNS requests only local host on both IPv4 and IPv6 stacks, use the syntax: + +~~~ corefile +. { + bind 127.0.0.1 ::1 +} +~~~ + +If the configuration comes up with several *bind* plugins, all addresses are consolidated together: +The following sample is equivalent to the preceding: + +~~~ corefile +. { + bind 127.0.0.1 + bind ::1 +} +~~~ + +The following server block, binds on localhost with its interface name (both "127.0.0.1" and "::1"): + +~~~ corefile +. { + bind lo +} +~~~ + +You can exclude some addresses by their IP or interface name (The following will only listen on `::1` or whatever addresses have been assigned to the `lo` interface): + +~~~ corefile +. { + bind lo { + except 127.0.0.1 + } +} +~~~ + +## Bugs + +### Avoiding Listener Contention + +TL;DR, When adding the _bind_ plugin to a server block, it must also be added to all other server blocks that listen on the same port. + +When more than one server block is configured to listen to a common port, those server blocks must either +all use the _bind_ plugin, or all use default binding (no _bind_ plugin). Note that "port" here refers the TCP/UDP port that +a server block is configured to serve (default 53) - not a network interface. 
For two server blocks listening on the same port, +if one uses the bind plugin and the other does not, two separate listeners will be created that will contend for serving +packets destined to the same address. Doing so will result in unpredictable behavior (requests may be randomly +served by either server). This happens because *without* the *bind* plugin, a server will bind to all +interfaces, and this will collide with another server if it's using *bind* to listen to an address +on the same port. For example, the following creates two servers that both listen on 127.0.0.1:53, +which would result in unpredictable behavior for queries in `a.bad.example.com`: + +``` +a.bad.example.com { + bind 127.0.0.1 + forward . 1.2.3.4 +} + +bad.example.com { + forward . 5.6.7.8 +} +``` + +Also on MacOS there is an (open) bug where this doesn't work properly. See +<https://github.com/miekg/dns/issues/724> for details, but no solution. diff --git a/plugin/bind/bind.go b/plugin/bind/bind.go new file mode 100644 index 0000000..cada8fa --- /dev/null +++ b/plugin/bind/bind.go @@ -0,0 +1,17 @@ +// Package bind allows binding to a specific interface instead of bind to all of them. +package bind + +import ( + "github.com/coredns/coredns/plugin" +) + +func init() { plugin.Register("bind", setup) } + +type bind struct { + Next plugin.Handler + addrs []string + except []string +} + +// Name implements plugin.Handler. +func (b *bind) Name() string { return "bind" } diff --git a/plugin/bind/log_test.go b/plugin/bind/log_test.go new file mode 100644 index 0000000..4ee3ffc --- /dev/null +++ b/plugin/bind/log_test.go @@ -0,0 +1,5 @@ +package bind + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/bind/setup.go b/plugin/bind/setup.go new file mode 100644 index 0000000..1bd3975 --- /dev/null +++ b/plugin/bind/setup.go @@ -0,0 +1,102 @@ +package bind + +import ( + "errors" + "fmt" + "net" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/log" + + "k8s.io/utils/strings/slices" +) + +func setup(c *caddy.Controller) error { + config := dnsserver.GetConfig(c) + // addresses will be consolidated over all BIND directives available in that BlocServer + all := []string{} + ifaces, err := net.Interfaces() + if err != nil { + log.Warning(plugin.Error("bind", fmt.Errorf("failed to get interfaces list, cannot bind by interface name: %s", err))) + } + + for c.Next() { + b, err := parse(c) + if err != nil { + return plugin.Error("bind", err) + } + + ips, err := listIP(b.addrs, ifaces) + if err != nil { + return plugin.Error("bind", err) + } + + except, err := listIP(b.except, ifaces) + if err != nil { + return plugin.Error("bind", err) + } + + for _, ip := range ips { + if !slices.Contains(except, ip) { + all = append(all, ip) + } + } + } + + config.ListenHosts = all + return nil +} + +func parse(c *caddy.Controller) (*bind, error) { + b := &bind{} + b.addrs = c.RemainingArgs() + if len(b.addrs) == 0 { + return nil, errors.New("at least one address or interface name is expected") + } + for c.NextBlock() { + switch c.Val() { + case "except": + b.except = c.RemainingArgs() + if len(b.except) == 0 { + return nil, errors.New("at least one address or interface must be given to except subdirective") + } + default: + return nil, fmt.Errorf("invalid option %q", c.Val()) + } + } + return b, nil +} + +// listIP returns a list of IP addresses from a list of arguments which can be 
either IP-Address or Interface-Name. +func listIP(args []string, ifaces []net.Interface) ([]string, error) { + all := []string{} + var isIface bool + for _, a := range args { + isIface = false + for _, iface := range ifaces { + if a == iface.Name { + isIface = true + addrs, err := iface.Addrs() + if err != nil { + return nil, fmt.Errorf("failed to get the IP addresses of the interface: %q", a) + } + for _, addr := range addrs { + if ipnet, ok := addr.(*net.IPNet); ok { + if ipnet.IP.To4() != nil || (!ipnet.IP.IsLinkLocalMulticast() && !ipnet.IP.IsLinkLocalUnicast()) { + all = append(all, ipnet.IP.String()) + } + } + } + } + } + if !isIface { + if net.ParseIP(a) == nil { + return nil, fmt.Errorf("not a valid IP address or interface name: %q", a) + } + all = append(all, a) + } + } + return all, nil +} diff --git a/plugin/bind/setup_test.go b/plugin/bind/setup_test.go new file mode 100644 index 0000000..e8c87b8 --- /dev/null +++ b/plugin/bind/setup_test.go @@ -0,0 +1,47 @@ +package bind + +import ( + "testing" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" +) + +func TestSetup(t *testing.T) { + for i, test := range []struct { + config string + expected []string + failing bool + }{ + {`bind 1.2.3.4`, []string{"1.2.3.4"}, false}, + {`bind`, nil, true}, + {`bind 1.2.3.invalid`, nil, true}, + {`bind 1.2.3.4 ::5`, []string{"1.2.3.4", "::5"}, false}, + {`bind ::1 1.2.3.4 ::5 127.9.9.0`, []string{"::1", "1.2.3.4", "::5", "127.9.9.0"}, false}, + {`bind ::1 1.2.3.4 ::5 127.9.9.0 noone`, nil, true}, + {`bind 1.2.3.4 lo`, []string{"1.2.3.4", "127.0.0.1", "::1"}, false}, + {"bind lo {\nexcept 127.0.0.1\n}\n", []string{"::1"}, false}, + } { + c := caddy.NewTestController("dns", test.config) + err := setup(c) + if err != nil { + if !test.failing { + t.Fatalf("Test %d, expected no errors, but got: %v", i, err) + } + continue + } + if test.failing { + t.Fatalf("Test %d, expected to failed but did not, returned values", i) + } + cfg := dnsserver.GetConfig(c) + if len(cfg.ListenHosts) != len(test.expected) { + t.Errorf("Test %d : expected the config's ListenHosts size to be %d, was %d", i, len(test.expected), len(cfg.ListenHosts)) + continue + } + for i, v := range test.expected { + if got, want := cfg.ListenHosts[i], v; got != want { + t.Errorf("Test %d : expected the config's ListenHost to be %s, was %s", i, want, got) + } + } + } +} diff --git a/plugin/bufsize/README.md b/plugin/bufsize/README.md new file mode 100644 index 0000000..0dc9623 --- /dev/null +++ b/plugin/bufsize/README.md @@ -0,0 +1,43 @@ +# bufsize +## Name +*bufsize* - limits EDNS0 buffer size to prevent IP fragmentation. + +## Description +*bufsize* limits a requester's UDP payload size to within a maximum value. +If a request with an OPT RR has a bufsize greater than the limit, the bufsize +of the request will be reduced. Otherwise the request is unaffected. +It prevents IP fragmentation, mitigating certain DNS vulnerabilities. +It cannot increase UDP size requested by the client, it can be reduced only. +This will only affect queries that have +an OPT RR ([EDNS(0)](https://www.rfc-editor.org/rfc/rfc6891)). + +## Syntax +```txt +bufsize [SIZE] +``` + +**[SIZE]** is an int value for setting the buffer size. +The default value is 1232, and the value must be within 512 - 4096. +Only one argument is acceptable, and it covers both IPv4 and IPv6. + +## Examples +Enable limiting the buffer size of outgoing query to the resolver (172.31.0.10): +```corefile +. { + bufsize 1100 + forward . 
172.31.0.10 + log +} +``` + +Enable limiting the buffer size as an authoritative nameserver: +```corefile +. { + bufsize 1220 + file db.example.org + log +} +``` + +## Considerations +- Setting 1232 bytes to bufsize may avoid fragmentation on the majority of networks in use today, but it depends on the MTU of the physical network links. diff --git a/plugin/bufsize/bufsize.go b/plugin/bufsize/bufsize.go new file mode 100644 index 0000000..00556c2 --- /dev/null +++ b/plugin/bufsize/bufsize.go @@ -0,0 +1,27 @@ +// Package bufsize implements a plugin that clamps EDNS0 buffer size preventing packet fragmentation. +package bufsize + +import ( + "context" + + "github.com/coredns/coredns/plugin" + + "github.com/miekg/dns" +) + +// Bufsize implements bufsize plugin. +type Bufsize struct { + Next plugin.Handler + Size int +} + +// ServeDNS implements the plugin.Handler interface. +func (buf Bufsize) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + if option := r.IsEdns0(); option != nil && int(option.UDPSize()) > buf.Size { + option.SetUDPSize(uint16(buf.Size)) + } + return plugin.NextOrFailure(buf.Name(), buf.Next, ctx, w, r) +} + +// Name implements the Handler interface. +func (buf Bufsize) Name() string { return "bufsize" } diff --git a/plugin/bufsize/bufsize_test.go b/plugin/bufsize/bufsize_test.go new file mode 100644 index 0000000..eb267dd --- /dev/null +++ b/plugin/bufsize/bufsize_test.go @@ -0,0 +1,102 @@ +package bufsize + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/plugin/whoami" + + "github.com/miekg/dns" +) + +func TestBufsize(t *testing.T) { + const maxBufSize = 1024 + + setUpWithRequestBufsz := func(bufferSize uint16) (Bufsize, *dns.Msg) { + p := Bufsize{ + Size: maxBufSize, + Next: whoami.Whoami{}, + } + r := new(dns.Msg) + r.SetQuestion(dns.Fqdn("."), dns.TypeA) + r.Question[0].Qclass = dns.ClassINET + if bufferSize > 0 { + r.SetEdns0(bufferSize, false) + } + return p, r + } + + t.Run("Limit response buffer size", func(t *testing.T) { + // GIVEN + // plugin initialized with maximum buffer size + // request has larger buffer size than allowed + p, r := setUpWithRequestBufsz(maxBufSize + 128) + + // WHEN + // request is processed + _, err := p.ServeDNS(context.Background(), &test.ResponseWriter{}, r) + + // THEN + // no error + // OPT RR present + // request buffer size is limited + if err != nil { + t.Errorf("unexpected error %s", err) + } + option := r.IsEdns0() + if option == nil { + t.Errorf("OPT RR not present") + } + if option.UDPSize() != maxBufSize { + t.Errorf("buffer size not limited") + } + }) + + t.Run("Do not increase response buffer size", func(t *testing.T) { + // GIVEN + // plugin initialized with maximum buffer size + // request has smaller buffer size than allowed + const smallerBufferSize = maxBufSize - 128 + p, r := setUpWithRequestBufsz(smallerBufferSize) + + // WHEN + // request is processed + _, err := p.ServeDNS(context.Background(), &test.ResponseWriter{}, r) + + // THEN + // no error + // request buffer size is not expanded + if err != nil { + t.Errorf("unexpected error %s", err) + } + option := r.IsEdns0() + if option == nil { + t.Errorf("OPT RR not present") + } + if option.UDPSize() != smallerBufferSize { + t.Errorf("buffer size should not be increased") + } + }) + + t.Run("Buffer size should not be set", func(t *testing.T) { + // GIVEN + // plugin initialized with maximum buffer size + // request has no EDNS0 option set + p, r := setUpWithRequestBufsz(0) + + 
// WHEN + // request is processed + _, err := p.ServeDNS(context.Background(), &test.ResponseWriter{}, r) + + // THEN + // no error + // OPT RR is not appended + if err != nil { + t.Errorf("unexpected error %s", err) + } + if r.IsEdns0() != nil { + t.Errorf("EDNS0 enabled for incoming request") + } + }) +} diff --git a/plugin/bufsize/setup.go b/plugin/bufsize/setup.go new file mode 100644 index 0000000..56113e6 --- /dev/null +++ b/plugin/bufsize/setup.go @@ -0,0 +1,52 @@ +package bufsize + +import ( + "strconv" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" +) + +func init() { plugin.Register("bufsize", setup) } + +func setup(c *caddy.Controller) error { + bufsize, err := parse(c) + if err != nil { + return plugin.Error("bufsize", err) + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + return Bufsize{Next: next, Size: bufsize} + }) + + return nil +} + +func parse(c *caddy.Controller) (int, error) { + // value from http://www.dnsflagday.net/2020/ + const defaultBufSize = 1232 + for c.Next() { + args := c.RemainingArgs() + switch len(args) { + case 0: + // Nothing specified; use defaultBufSize + return defaultBufSize, nil + case 1: + // Specified value is needed to verify + bufsize, err := strconv.Atoi(args[0]) + if err != nil { + return -1, plugin.Error("bufsize", c.ArgErr()) + } + // Follows RFC 6891 + if bufsize < 512 || bufsize > 4096 { + return -1, plugin.Error("bufsize", c.ArgErr()) + } + return bufsize, nil + default: + // Only 1 argument is acceptable + return -1, plugin.Error("bufsize", c.ArgErr()) + } + } + return -1, plugin.Error("bufsize", c.ArgErr()) +} diff --git a/plugin/bufsize/setup_test.go b/plugin/bufsize/setup_test.go new file mode 100644 index 0000000..5bf7b80 --- /dev/null +++ b/plugin/bufsize/setup_test.go @@ -0,0 +1,47 @@ +package bufsize + +import ( + "strings" + "testing" + + "github.com/coredns/caddy" +) + +func TestSetupBufsize(t *testing.T) { + tests := []struct { + input string + shouldErr bool + expectedData int + expectedErrContent string // substring from the expected error. Empty for positive cases. + }{ + {`bufsize`, false, 1232, ""}, + {`bufsize "1220"`, false, 1220, ""}, + {`bufsize "5000"`, true, -1, "plugin"}, + {`bufsize "512 512"`, true, -1, "plugin"}, + {`bufsize "511"`, true, -1, "plugin"}, + {`bufsize "abc123"`, true, -1, "plugin"}, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + bufsize, err := parse(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected error but found %s for input %s", i, err, test.input) + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: Error found for input %s. Error: %v", i, test.input, err) + } + + if !strings.Contains(err.Error(), test.expectedErrContent) { + t.Errorf("Test %d: Expected error to contain: %v, found error: %v, input: %s", i, test.expectedErrContent, err, test.input) + } + } + + if !test.shouldErr && bufsize != test.expectedData { + t.Errorf("Test %d: Bufsize not correctly set for input %s. Expected: %d, actual: %d", i, test.input, test.expectedData, bufsize) + } + } +} diff --git a/plugin/cache/README.md b/plugin/cache/README.md new file mode 100644 index 0000000..d516a91 --- /dev/null +++ b/plugin/cache/README.md @@ -0,0 +1,144 @@ +# cache + +## Name + +*cache* - enables a frontend cache. + +## Description + +With *cache* enabled, all records except zone transfers and metadata records will be cached for up to +3600s. 
Caching is mostly useful in a scenario when fetching data from the backend (upstream, +database, etc.) is expensive. + +*Cache* will pass DNSSEC (DNSSEC OK; DO) options through the plugin for upstream queries. + +This plugin can only be used once per Server Block. + +## Syntax + +~~~ txt +cache [TTL] [ZONES...] +~~~ + +* **TTL** max TTL in seconds. If not specified, the maximum TTL will be used, which is 3600 for + NOERROR responses and 1800 for denial of existence ones. + Setting a TTL of 300: `cache 300` would cache records up to 300 seconds. +* **ZONES** zones it should cache for. If empty, the zones from the configuration block are used. + +Each element in the cache is cached according to its TTL (with **TTL** as the max). +A cache is divided into 256 shards, each holding up to 39 items by default - for a total size +of 256 * 39 = 9984 items. + +If you want more control: + +~~~ txt +cache [TTL] [ZONES...] { + success CAPACITY [TTL] [MINTTL] + denial CAPACITY [TTL] [MINTTL] + prefetch AMOUNT [[DURATION] [PERCENTAGE%]] + serve_stale [DURATION] [REFRESH_MODE] + servfail DURATION + disable success|denial [ZONES...] + keepttl +} +~~~ + +* **TTL** and **ZONES** as above. +* `success`, override the settings for caching successful responses. **CAPACITY** indicates the maximum + number of packets we cache before we start evicting (*randomly*). **TTL** overrides the cache maximum TTL. + **MINTTL** overrides the cache minimum TTL (default 5), which can be useful to limit queries to the backend. +* `denial`, override the settings for caching denial of existence responses. **CAPACITY** indicates the maximum + number of packets we cache before we start evicting (LRU). **TTL** overrides the cache maximum TTL. + **MINTTL** overrides the cache minimum TTL (default 5), which can be useful to limit queries to the backend. + There is a third category (`error`) but those responses are never cached. +* `prefetch` will prefetch popular items when they are about to be expunged from the cache. + Popular means **AMOUNT** queries have been seen with no gaps of **DURATION** or more between them. + **DURATION** defaults to 1m. Prefetching will happen when the TTL drops below **PERCENTAGE**, + which defaults to `10%`, or latest 1 second before TTL expiration. Values should be in the range `[10%, 90%]`. + Note the percent sign is mandatory. **PERCENTAGE** is treated as an `int`. +* `serve_stale`, when serve\_stale is set, cache will always serve an expired entry to a client if there is one + available as long as it has not been expired for longer than **DURATION** (default 1 hour). By default, the _cache_ plugin will + attempt to refresh the cache entry after sending the expired cache entry to the client. The + responses have a TTL of 0. **REFRESH_MODE** controls the timing of the expired cache entry refresh. + `verify` will first verify that an entry is still unavailable from the source before sending the expired entry to the client. + `immediate` will immediately send the expired entry to the client before + checking to see if the entry is available from the source. **REFRESH_MODE** defaults to `immediate`. Setting this + value to `verify` can lead to increased latency when serving stale responses, but will prevent stale entries + from ever being served if an updated response can be retrieved from the source. +* `servfail` cache SERVFAIL responses for **DURATION**. Setting **DURATION** to 0 will disable caching of SERVFAIL + responses. If this option is not set, SERVFAIL responses will be cached for 5 seconds. 
**DURATION** may not be + greater than 5 minutes. +* `disable` disable the success or denial cache for the listed **ZONES**. If no **ZONES** are given, the specified + cache will be disabled for all zones. +* `keepttl` do not age TTL when serving responses from cache. The entry will still be removed from cache + when the TTL expires as normal, but until it expires responses will include the original TTL instead + of the remaining TTL. This can be useful if CoreDNS is used as an authoritative server and you want + to serve a consistent TTL to downstream clients. This is **NOT** recommended when CoreDNS is caching + records it is not authoritative for because it could result in downstream clients using stale answers. + +## Capacity and Eviction + +If **CAPACITY** _is not_ specified, the default cache size is 9984 per cache. The minimum allowed cache size is 1024. +If **CAPACITY** _is_ specified, the actual cache size used will be rounded down to the nearest number divisible by 256 (so all shards are equal in size). + +Eviction is done per shard. In effect, when a shard reaches capacity, items are evicted from that shard. +Since shards don't fill up perfectly evenly, evictions will occur before the entire cache reaches full capacity. +Each shard capacity is equal to the total cache size / number of shards (256). Eviction is random, not TTL based. +Entries with 0 TTL will remain in the cache until randomly evicted when the shard reaches capacity. + +## Metrics + +If monitoring is enabled (via the *prometheus* plugin) then the following metrics are exported: + +* `coredns_cache_entries{server, type, zones, view}` - Total elements in the cache by cache type. +* `coredns_cache_hits_total{server, type, zones, view}` - Counter of cache hits by cache type. +* `coredns_cache_misses_total{server, zones, view}` - Counter of cache misses. - Deprecated, derive misses from cache hits/requests counters. +* `coredns_cache_requests_total{server, zones, view}` - Counter of cache requests. +* `coredns_cache_prefetch_total{server, zones, view}` - Counter of times the cache has prefetched a cached item. +* `coredns_cache_drops_total{server, zones, view}` - Counter of responses excluded from the cache due to request/response question name mismatch. +* `coredns_cache_served_stale_total{server, zones, view}` - Counter of requests served from stale cache entries. +* `coredns_cache_evictions_total{server, type, zones, view}` - Counter of cache evictions. + +Cache types are either "denial" or "success". `Server` is the server handling the request, see the +prometheus plugin for documentation. + +## Examples + +Enable caching for all zones, but cap everything to a TTL of 10 seconds: + +~~~ corefile +. { + cache 10 + whoami +} +~~~ + +Proxy to Google Public DNS and only cache responses for example.org (or below). + +~~~ corefile +. { + forward . 8.8.8.8:53 + cache example.org +} +~~~ + +Enable caching for `example.org`, keep a positive cache size of 5000 and a negative cache size of 2500: + +~~~ corefile +example.org { + cache { + success 5000 + denial 2500 + } +} +~~~ + +Enable caching for `example.org`, but do not cache denials in `sub.example.org`: + +~~~ corefile +example.org { + cache { + disable denial sub.example.org + } +} +~~~ diff --git a/plugin/cache/cache.go b/plugin/cache/cache.go new file mode 100644 index 0000000..1378263 --- /dev/null +++ b/plugin/cache/cache.go @@ -0,0 +1,320 @@ +// Package cache implements a cache. 
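+// The plugin keeps two sharded caches, one for successful responses and one for
+// denial of existence and SERVFAIL responses, each entry keyed on a hash of the
+// lowercased query name, the query type and the DO/CD bits.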
+package cache + +import ( + "hash/fnv" + "net" + "time" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/cache" + "github.com/coredns/coredns/plugin/pkg/dnsutil" + "github.com/coredns/coredns/plugin/pkg/response" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// Cache is a plugin that looks up responses in a cache and caches replies. +// It has a success and a denial of existence cache. +type Cache struct { + Next plugin.Handler + Zones []string + + zonesMetricLabel string + viewMetricLabel string + + ncache *cache.Cache + ncap int + nttl time.Duration + minnttl time.Duration + + pcache *cache.Cache + pcap int + pttl time.Duration + minpttl time.Duration + failttl time.Duration // TTL for caching SERVFAIL responses + + // Prefetch. + prefetch int + duration time.Duration + percentage int + + // Stale serve + staleUpTo time.Duration + verifyStale bool + + // Positive/negative zone exceptions + pexcept []string + nexcept []string + + // Keep ttl option + keepttl bool + + // Testing. + now func() time.Time +} + +// New returns an initialized Cache with default settings. It's up to the +// caller to set the Next handler. +func New() *Cache { + return &Cache{ + Zones: []string{"."}, + pcap: defaultCap, + pcache: cache.New(defaultCap), + pttl: maxTTL, + minpttl: minTTL, + ncap: defaultCap, + ncache: cache.New(defaultCap), + nttl: maxNTTL, + minnttl: minNTTL, + failttl: minNTTL, + prefetch: 0, + duration: 1 * time.Minute, + percentage: 10, + now: time.Now, + } +} + +// key returns key under which we store the item, -1 will be returned if we don't store the message. +// Currently we do not cache Truncated, errors zone transfers or dynamic update messages. +// qname holds the already lowercased qname. +func key(qname string, m *dns.Msg, t response.Type, do, cd bool) (bool, uint64) { + // We don't store truncated responses. + if m.Truncated { + return false, 0 + } + // Nor errors or Meta or Update. + if t == response.OtherError || t == response.Meta || t == response.Update { + return false, 0 + } + + return true, hash(qname, m.Question[0].Qtype, do, cd) +} + +var one = []byte("1") +var zero = []byte("0") + +func hash(qname string, qtype uint16, do, cd bool) uint64 { + h := fnv.New64() + + if do { + h.Write(one) + } else { + h.Write(zero) + } + + if cd { + h.Write(one) + } else { + h.Write(zero) + } + + h.Write([]byte{byte(qtype >> 8)}) + h.Write([]byte{byte(qtype)}) + h.Write([]byte(qname)) + return h.Sum64() +} + +func computeTTL(msgTTL, minTTL, maxTTL time.Duration) time.Duration { + ttl := msgTTL + if ttl < minTTL { + ttl = minTTL + } + if ttl > maxTTL { + ttl = maxTTL + } + return ttl +} + +// ResponseWriter is a response writer that caches the reply message. +type ResponseWriter struct { + dns.ResponseWriter + *Cache + state request.Request + server string // Server handling the request. + + do bool // When true the original request had the DO bit set. + cd bool // When true the original request had the CD bit set. + ad bool // When true the original request had the AD bit set. + prefetch bool // When true write nothing back to the client. + remoteAddr net.Addr + + wildcardFunc func() string // function to retrieve wildcard name that synthesized the result. + + pexcept []string // positive zone exceptions + nexcept []string // negative zone exceptions +} + +// newPrefetchResponseWriter returns a Cache ResponseWriter to be used in +// prefetch requests. 
It ensures RemoteAddr() can be called even after the +// original connection has already been closed. +func newPrefetchResponseWriter(server string, state request.Request, c *Cache) *ResponseWriter { + // Resolve the address now, the connection might be already closed when the + // actual prefetch request is made. + addr := state.W.RemoteAddr() + // The protocol of the client triggering a cache prefetch doesn't matter. + // The address type is used by request.Proto to determine the response size, + // and using TCP ensures the message isn't unnecessarily truncated. + if u, ok := addr.(*net.UDPAddr); ok { + addr = &net.TCPAddr{IP: u.IP, Port: u.Port, Zone: u.Zone} + } + + return &ResponseWriter{ + ResponseWriter: state.W, + Cache: c, + state: state, + server: server, + do: state.Do(), + cd: state.Req.CheckingDisabled, + prefetch: true, + remoteAddr: addr, + } +} + +// RemoteAddr implements the dns.ResponseWriter interface. +func (w *ResponseWriter) RemoteAddr() net.Addr { + if w.remoteAddr != nil { + return w.remoteAddr + } + return w.ResponseWriter.RemoteAddr() +} + +// WriteMsg implements the dns.ResponseWriter interface. +func (w *ResponseWriter) WriteMsg(res *dns.Msg) error { + mt, _ := response.Typify(res, w.now().UTC()) + + // key returns empty string for anything we don't want to cache. + hasKey, key := key(w.state.Name(), res, mt, w.do, w.cd) + + msgTTL := dnsutil.MinimalTTL(res, mt) + var duration time.Duration + if mt == response.NameError || mt == response.NoData { + duration = computeTTL(msgTTL, w.minnttl, w.nttl) + } else if mt == response.ServerError { + duration = w.failttl + } else { + duration = computeTTL(msgTTL, w.minpttl, w.pttl) + } + + if hasKey && duration > 0 { + if w.state.Match(res) { + w.set(res, key, mt, duration) + cacheSize.WithLabelValues(w.server, Success, w.zonesMetricLabel, w.viewMetricLabel).Set(float64(w.pcache.Len())) + cacheSize.WithLabelValues(w.server, Denial, w.zonesMetricLabel, w.viewMetricLabel).Set(float64(w.ncache.Len())) + } else { + // Don't log it, but increment counter + cacheDrops.WithLabelValues(w.server, w.zonesMetricLabel, w.viewMetricLabel).Inc() + } + } + + if w.prefetch { + return nil + } + + // Apply capped TTL to this reply to avoid jarring TTL experience 1799 -> 8 (e.g.) 
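+ // The same computed cache duration is applied to the answer, authority and
+ // additional sections, so the TTLs the client receives match what this cache
+ // entry will serve.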
+ ttl := uint32(duration.Seconds()) + res.Answer = filterRRSlice(res.Answer, ttl, false) + res.Ns = filterRRSlice(res.Ns, ttl, false) + res.Extra = filterRRSlice(res.Extra, ttl, false) + + if !w.do && !w.ad { + // unset AD bit if requester is not OK with DNSSEC + // But retain AD bit if requester set the AD bit in the request, per RFC6840 5.7-5.8 + res.AuthenticatedData = false + } + + return w.ResponseWriter.WriteMsg(res) +} + +func (w *ResponseWriter) set(m *dns.Msg, key uint64, mt response.Type, duration time.Duration) { + // duration is expected > 0 + // and key is valid + switch mt { + case response.NoError, response.Delegation: + if plugin.Zones(w.pexcept).Matches(m.Question[0].Name) != "" { + // zone is in exception list, do not cache + return + } + i := newItem(m, w.now(), duration) + if w.wildcardFunc != nil { + i.wildcard = w.wildcardFunc() + } + if w.pcache.Add(key, i) { + evictions.WithLabelValues(w.server, Success, w.zonesMetricLabel, w.viewMetricLabel).Inc() + } + // when pre-fetching, remove the negative cache entry if it exists + if w.prefetch { + w.ncache.Remove(key) + } + + case response.NameError, response.NoData, response.ServerError: + if plugin.Zones(w.nexcept).Matches(m.Question[0].Name) != "" { + // zone is in exception list, do not cache + return + } + i := newItem(m, w.now(), duration) + if w.wildcardFunc != nil { + i.wildcard = w.wildcardFunc() + } + if w.ncache.Add(key, i) { + evictions.WithLabelValues(w.server, Denial, w.zonesMetricLabel, w.viewMetricLabel).Inc() + } + + case response.OtherError: + // don't cache these + default: + log.Warningf("Caching called with unknown classification: %d", mt) + } +} + +// Write implements the dns.ResponseWriter interface. +func (w *ResponseWriter) Write(buf []byte) (int, error) { + log.Warning("Caching called with Write: not caching reply") + if w.prefetch { + return 0, nil + } + n, err := w.ResponseWriter.Write(buf) + return n, err +} + +// verifyStaleResponseWriter is a response writer that only writes messages if they should replace a +// stale cache entry, and otherwise discards them. +type verifyStaleResponseWriter struct { + *ResponseWriter + refreshed bool // set to true if the last WriteMsg wrote to ResponseWriter, false otherwise. +} + +// newVerifyStaleResponseWriter returns a ResponseWriter to be used when verifying stale cache +// entries. It only forward writes if an entry was successfully refreshed according to RFC8767, +// section 4 (response is NoError or NXDomain), and ignores any other response. +func newVerifyStaleResponseWriter(w *ResponseWriter) *verifyStaleResponseWriter { + return &verifyStaleResponseWriter{ + w, + false, + } +} + +// WriteMsg implements the dns.ResponseWriter interface. +func (w *verifyStaleResponseWriter) WriteMsg(res *dns.Msg) error { + w.refreshed = false + if res.Rcode == dns.RcodeSuccess || res.Rcode == dns.RcodeNameError { + w.refreshed = true + return w.ResponseWriter.WriteMsg(res) // stores to the cache and send to client + } + return nil // else discard +} + +const ( + maxTTL = dnsutil.MaximumDefaulTTL + minTTL = dnsutil.MinimalDefaultTTL + maxNTTL = dnsutil.MaximumDefaulTTL / 2 + minNTTL = dnsutil.MinimalDefaultTTL + + defaultCap = 10000 // default capacity of the cache. + + // Success is the class for caching positive caching. + Success = "success" + // Denial is the class defined for negative caching. 
+ Denial = "denial" +) diff --git a/plugin/cache/cache_test.go b/plugin/cache/cache_test.go new file mode 100644 index 0000000..947c675 --- /dev/null +++ b/plugin/cache/cache_test.go @@ -0,0 +1,901 @@ +package cache + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/metadata" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/pkg/response" + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +func cacheMsg(m *dns.Msg, tc test.Case) *dns.Msg { + m.RecursionAvailable = tc.RecursionAvailable + m.AuthenticatedData = tc.AuthenticatedData + m.CheckingDisabled = tc.CheckingDisabled + m.Authoritative = tc.Authoritative + m.Rcode = tc.Rcode + m.Truncated = tc.Truncated + m.Answer = tc.Answer + m.Ns = tc.Ns + // m.Extra = tc.in.Extra don't copy Extra, because we don't care and fake EDNS0 DO with tc.Do. + return m +} + +func newTestCache(ttl time.Duration) (*Cache, *ResponseWriter) { + c := New() + c.pttl = ttl + c.nttl = ttl + + crr := &ResponseWriter{ResponseWriter: nil, Cache: c} + crr.nexcept = []string{"neg-disabled.example.org."} + crr.pexcept = []string{"pos-disabled.example.org."} + + return c, crr +} + +// TestCacheInsertion verifies the insertion of items to the cache. +func TestCacheInsertion(t *testing.T) { + cacheTestCases := []struct { + name string + out test.Case // the expected message coming "out" of cache + in test.Case // the test message going "in" to cache + shouldCache bool + }{ + { + name: "test ad bit cache", + out: test.Case{ + Qname: "miek.nl.", Qtype: dns.TypeMX, + Answer: []dns.RR{ + test.MX("miek.nl. 3600 IN MX 1 aspmx.l.google.com."), + test.MX("miek.nl. 3600 IN MX 10 aspmx2.googlemail.com."), + }, + RecursionAvailable: true, + AuthenticatedData: true, + }, + in: test.Case{ + Qname: "miek.nl.", Qtype: dns.TypeMX, + Answer: []dns.RR{ + test.MX("miek.nl. 3601 IN MX 1 aspmx.l.google.com."), + test.MX("miek.nl. 3601 IN MX 10 aspmx2.googlemail.com."), + }, + RecursionAvailable: true, + AuthenticatedData: true, + }, + shouldCache: true, + }, + { + name: "test case sensitivity cache", + out: test.Case{ + Qname: "miek.nl.", Qtype: dns.TypeMX, + Answer: []dns.RR{ + test.MX("miek.nl. 3600 IN MX 1 aspmx.l.google.com."), + test.MX("miek.nl. 3600 IN MX 10 aspmx2.googlemail.com."), + }, + RecursionAvailable: true, + AuthenticatedData: true, + }, + in: test.Case{ + Qname: "mIEK.nL.", Qtype: dns.TypeMX, + Answer: []dns.RR{ + test.MX("miek.nl. 3601 IN MX 1 aspmx.l.google.com."), + test.MX("miek.nl. 3601 IN MX 10 aspmx2.googlemail.com."), + }, + RecursionAvailable: true, + AuthenticatedData: true, + }, + shouldCache: true, + }, + { + name: "test truncated responses shouldn't cache", + in: test.Case{ + Qname: "miek.nl.", Qtype: dns.TypeMX, + Answer: []dns.RR{test.MX("miek.nl. 1800 IN MX 1 aspmx.l.google.com.")}, + Truncated: true, + }, + shouldCache: false, + }, + { + name: "test dns.RcodeNameError cache", + out: test.Case{ + Rcode: dns.RcodeNameError, + Qname: "example.org.", Qtype: dns.TypeA, + Ns: []dns.RR{ + test.SOA("example.org. 3600 IN SOA sns.dns.icann.org. noc.dns.icann.org. 2016082540 7200 3600 1209600 3600"), + }, + RecursionAvailable: true, + }, + in: test.Case{ + Rcode: dns.RcodeNameError, + Qname: "example.org.", Qtype: dns.TypeA, + Ns: []dns.RR{ + test.SOA("example.org. 3600 IN SOA sns.dns.icann.org. noc.dns.icann.org. 
2016082540 7200 3600 1209600 3600"), + }, + RecursionAvailable: true, + }, + shouldCache: true, + }, + { + name: "test dns.RcodeServerFailure cache", + out: test.Case{ + Rcode: dns.RcodeServerFailure, + Qname: "example.org.", Qtype: dns.TypeA, + Ns: []dns.RR{}, + RecursionAvailable: true, + }, + in: test.Case{ + Rcode: dns.RcodeServerFailure, + Qname: "example.org.", Qtype: dns.TypeA, + Ns: []dns.RR{}, + RecursionAvailable: true, + }, + shouldCache: true, + }, + { + name: "test dns.RcodeNotImplemented cache", + out: test.Case{ + Rcode: dns.RcodeNotImplemented, + Qname: "example.org.", Qtype: dns.TypeA, + Ns: []dns.RR{}, + RecursionAvailable: true, + }, + in: test.Case{ + Rcode: dns.RcodeNotImplemented, + Qname: "example.org.", Qtype: dns.TypeA, + Ns: []dns.RR{}, + RecursionAvailable: true, + }, + shouldCache: true, + }, + { + name: "test expired RRSIG doesn't cache", + in: test.Case{ + Qname: "miek.nl.", Qtype: dns.TypeMX, + Do: true, + Answer: []dns.RR{ + test.MX("miek.nl. 3600 IN MX 1 aspmx.l.google.com."), + test.MX("miek.nl. 3600 IN MX 10 aspmx2.googlemail.com."), + test.RRSIG("miek.nl. 1800 IN RRSIG MX 8 2 1800 20160521031301 20160421031301 12051 miek.nl. lAaEzB5teQLLKyDenatmyhca7blLRg9DoGNrhe3NReBZN5C5/pMQk8Jc u25hv2fW23/SLm5IC2zaDpp2Fzgm6Jf7e90/yLcwQPuE7JjS55WMF+HE LEh7Z6AEb+Iq4BWmNhUz6gPxD4d9eRMs7EAzk13o1NYi5/JhfL6IlaYy qkc="), + }, + RecursionAvailable: true, + }, + shouldCache: false, + }, + { + name: "test DO bit with RRSIG not expired cache", + out: test.Case{ + Qname: "example.org.", Qtype: dns.TypeMX, + Do: true, + Answer: []dns.RR{ + test.MX("example.org. 3600 IN MX 1 aspmx.l.google.com."), + test.MX("example.org. 3600 IN MX 10 aspmx2.googlemail.com."), + test.RRSIG("example.org. 3600 IN RRSIG MX 8 2 1800 20170521031301 20170421031301 12051 miek.nl. lAaEzB5teQLLKyDenatmyhca7blLRg9DoGNrhe3NReBZN5C5/pMQk8Jc u25hv2fW23/SLm5IC2zaDpp2Fzgm6Jf7e90/yLcwQPuE7JjS55WMF+HE LEh7Z6AEb+Iq4BWmNhUz6gPxD4d9eRMs7EAzk13o1NYi5/JhfL6IlaYy qkc="), + }, + RecursionAvailable: true, + }, + in: test.Case{ + Qname: "example.org.", Qtype: dns.TypeMX, + Do: true, + Answer: []dns.RR{ + test.MX("example.org. 3600 IN MX 1 aspmx.l.google.com."), + test.MX("example.org. 3600 IN MX 10 aspmx2.googlemail.com."), + test.RRSIG("example.org. 1800 IN RRSIG MX 8 2 1800 20170521031301 20170421031301 12051 miek.nl. lAaEzB5teQLLKyDenatmyhca7blLRg9DoGNrhe3NReBZN5C5/pMQk8Jc u25hv2fW23/SLm5IC2zaDpp2Fzgm6Jf7e90/yLcwQPuE7JjS55WMF+HE LEh7Z6AEb+Iq4BWmNhUz6gPxD4d9eRMs7EAzk13o1NYi5/JhfL6IlaYy qkc="), + }, + RecursionAvailable: true, + }, + shouldCache: true, + }, + { + name: "test CD bit cache", + out: test.Case{ + Rcode: dns.RcodeSuccess, + Qname: "dnssec-failed.org.", + Qtype: dns.TypeA, + Answer: []dns.RR{ + test.A("dnssec-failed.org. 3600 IN A 127.0.0.1"), + }, + CheckingDisabled: true, + }, + in: test.Case{ + Rcode: dns.RcodeSuccess, + Qname: "dnssec-failed.org.", + Answer: []dns.RR{ + test.A("dnssec-failed.org. 3600 IN A 127.0.0.1"), + }, + Qtype: dns.TypeA, + CheckingDisabled: true, + }, + shouldCache: true, + }, + { + name: "test negative zone exception shouldn't cache", + in: test.Case{ + Rcode: dns.RcodeNameError, + Qname: "neg-disabled.example.org.", Qtype: dns.TypeA, + Ns: []dns.RR{ + test.SOA("example.org. 3600 IN SOA sns.dns.icann.org. noc.dns.icann.org. 
2016082540 7200 3600 1209600 3600"), + }, + }, + shouldCache: false, + }, + { + name: "test positive zone exception shouldn't cache", + in: test.Case{ + Rcode: dns.RcodeSuccess, + Qname: "pos-disabled.example.org.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.A("pos-disabled.example.org. 3600 IN A 127.0.0.1"), + }, + }, + shouldCache: false, + }, + { + name: "test positive zone exception with negative answer cache", + in: test.Case{ + Rcode: dns.RcodeNameError, + Qname: "pos-disabled.example.org.", Qtype: dns.TypeA, + Ns: []dns.RR{ + test.SOA("example.org. 3600 IN SOA sns.dns.icann.org. noc.dns.icann.org. 2016082540 7200 3600 1209600 3600"), + }, + }, + out: test.Case{ + Rcode: dns.RcodeNameError, + Qname: "pos-disabled.example.org.", Qtype: dns.TypeA, + Ns: []dns.RR{ + test.SOA("example.org. 3600 IN SOA sns.dns.icann.org. noc.dns.icann.org. 2016082540 7200 3600 1209600 3600"), + }, + }, + shouldCache: true, + }, + { + name: "test negative zone exception with positive answer cache", + in: test.Case{ + Rcode: dns.RcodeSuccess, + Qname: "neg-disabled.example.org.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.A("neg-disabled.example.org. 3600 IN A 127.0.0.1"), + }, + }, + out: test.Case{ + Rcode: dns.RcodeSuccess, + Qname: "neg-disabled.example.org.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.A("neg-disabled.example.org. 3600 IN A 127.0.0.1"), + }, + }, + shouldCache: true, + }, + } + now, _ := time.Parse(time.UnixDate, "Fri Apr 21 10:51:21 BST 2017") + utc := now.UTC() + + for _, tc := range cacheTestCases { + t.Run(tc.name, func(t *testing.T) { + // Create a new cache every time to prevent accidental comparison with a previous item. + c, crr := newTestCache(maxTTL) + + m := tc.in.Msg() + m = cacheMsg(m, tc.in) + + state := request.Request{W: &test.ResponseWriter{}, Req: m} + + mt, _ := response.Typify(m, utc) + valid, k := key(state.Name(), m, mt, state.Do(), state.Req.CheckingDisabled) + + if valid { + // Insert cache entry + crr.set(m, k, mt, c.pttl) + } + + // Attempt to retrieve cache entry + i := c.getIgnoreTTL(time.Now().UTC(), state, "dns://:53") + found := i != nil + + if !tc.shouldCache && found { + t.Fatalf("Cached message that should not have been cached: %s", state.Name()) + } + if tc.shouldCache && !found { + t.Fatalf("Did not cache message that should have been cached: %s", state.Name()) + } + + if found { + resp := i.toMsg(m, time.Now().UTC(), state.Do(), m.AuthenticatedData) + + // TODO: If we incorporate these individual checks into the + // test.Header function, we can eliminate them from here. + // Cache entries are always Authoritative. 
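+ // The reply is rebuilt from the cached item, which is expected to set the AA
+ // bit itself regardless of the flags on the original response.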
+ if resp.Authoritative != true { + t.Error("Expected Authoritative Answer bit to be true, but was false") + } + if resp.AuthenticatedData != tc.out.AuthenticatedData { + t.Errorf("Expected Authenticated Data bit to be %t, but got %t", tc.out.AuthenticatedData, resp.AuthenticatedData) + } + if resp.RecursionAvailable != tc.out.RecursionAvailable { + t.Errorf("Expected Recursion Available bit to be %t, but got %t", tc.out.RecursionAvailable, resp.RecursionAvailable) + } + if resp.CheckingDisabled != tc.out.CheckingDisabled { + t.Errorf("Expected Checking Disabled bit to be %t, but got %t", tc.out.CheckingDisabled, resp.CheckingDisabled) + } + + if err := test.Header(tc.out, resp); err != nil { + t.Logf("Cache %v", resp) + t.Error(err) + } + if err := test.Section(tc.out, test.Answer, resp.Answer); err != nil { + t.Logf("Cache %v -- %v", test.Answer, resp.Answer) + t.Error(err) + } + if err := test.Section(tc.out, test.Ns, resp.Ns); err != nil { + t.Error(err) + } + if err := test.Section(tc.out, test.Extra, resp.Extra); err != nil { + t.Error(err) + } + } + }) + } +} + +func TestCacheZeroTTL(t *testing.T) { + c := New() + c.minpttl = 0 + c.minnttl = 0 + c.Next = ttlBackend(0) + + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + ctx := context.TODO() + + c.ServeDNS(ctx, &test.ResponseWriter{}, req) + if c.pcache.Len() != 0 { + t.Errorf("Msg with 0 TTL should not have been cached") + } + if c.ncache.Len() != 0 { + t.Errorf("Msg with 0 TTL should not have been cached") + } +} + +func TestCacheServfailTTL0(t *testing.T) { + c := New() + c.minpttl = minTTL + c.minnttl = minNTTL + c.failttl = 0 + c.Next = servFailBackend(0) + + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + ctx := context.TODO() + + c.ServeDNS(ctx, &test.ResponseWriter{}, req) + if c.ncache.Len() != 0 { + t.Errorf("SERVFAIL response should not have been cached") + } +} + +func TestServeFromStaleCache(t *testing.T) { + c := New() + c.Next = ttlBackend(60) + + req := new(dns.Msg) + req.SetQuestion("cached.org.", dns.TypeA) + ctx := context.TODO() + + // Cache cached.org. with 60s TTL + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + c.staleUpTo = 1 * time.Hour + c.ServeDNS(ctx, rec, req) + if c.pcache.Len() != 1 { + t.Fatalf("Msg with > 0 TTL should have been cached") + } + + // No more backend resolutions, just from cache if available. + c.Next = plugin.HandlerFunc(func(context.Context, dns.ResponseWriter, *dns.Msg) (int, error) { + return 255, nil // Below, a 255 means we tried querying upstream. + }) + + tests := []struct { + name string + futureMinutes int + expectedResult int + }{ + {"cached.org.", 30, 0}, + {"cached.org.", 60, 0}, + {"cached.org.", 70, 255}, + + {"notcached.org.", 30, 255}, + {"notcached.org.", 60, 255}, + {"notcached.org.", 70, 255}, + } + + for i, tt := range tests { + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + c.now = func() time.Time { return time.Now().Add(time.Duration(tt.futureMinutes) * time.Minute) } + r := req.Copy() + r.SetQuestion(tt.name, dns.TypeA) + if ret, _ := c.ServeDNS(ctx, rec, r); ret != tt.expectedResult { + t.Errorf("Test %d: expecting %v; got %v", i, tt.expectedResult, ret) + } + } +} + +func TestServeFromStaleCacheFetchVerify(t *testing.T) { + c := New() + c.Next = ttlBackend(120) + + req := new(dns.Msg) + req.SetQuestion("cached.org.", dns.TypeA) + ctx := context.TODO() + + // Cache cached.org. 
with 120s TTL + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + c.staleUpTo = 1 * time.Hour + c.verifyStale = true + c.ServeDNS(ctx, rec, req) + if c.pcache.Len() != 1 { + t.Fatalf("Msg with > 0 TTL should have been cached") + } + + tests := []struct { + name string + upstreamRCode int + upstreamTtl int + futureMinutes int + expectedRCode int + expectedTtl int + }{ + // After 1 minutes of initial TTL, we should see a cached response + {"cached.org.", dns.RcodeSuccess, 200, 1, dns.RcodeSuccess, 60}, // ttl = 120 - 60 -- not refreshed + + // After the 2 more minutes, we should see upstream responses because upstream is available + {"cached.org.", dns.RcodeSuccess, 200, 3, dns.RcodeSuccess, 200}, + + // After the TTL expired, if the server fails we should get the cached entry + {"cached.org.", dns.RcodeServerFailure, 200, 7, dns.RcodeSuccess, 0}, + + // After 1 more minutes, if the server serves nxdomain we should see them (despite being within the serve stale period) + {"cached.org.", dns.RcodeNameError, 150, 8, dns.RcodeNameError, 150}, + } + + for i, tt := range tests { + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + c.now = func() time.Time { return time.Now().Add(time.Duration(tt.futureMinutes) * time.Minute) } + + if tt.upstreamRCode == dns.RcodeSuccess { + c.Next = ttlBackend(tt.upstreamTtl) + } else if tt.upstreamRCode == dns.RcodeServerFailure { + // Make upstream fail, should now rely on cache during the c.staleUpTo period + c.Next = servFailBackend(tt.upstreamTtl) + } else if tt.upstreamRCode == dns.RcodeNameError { + c.Next = nxDomainBackend(tt.upstreamTtl) + } else { + t.Fatal("upstream code not implemented") + } + + r := req.Copy() + r.SetQuestion(tt.name, dns.TypeA) + ret, _ := c.ServeDNS(ctx, rec, r) + if ret != tt.expectedRCode { + t.Errorf("Test %d: expected rcode=%v, got rcode=%v", i, tt.expectedRCode, ret) + continue + } + if ret == dns.RcodeSuccess { + recTtl := rec.Msg.Answer[0].Header().Ttl + if tt.expectedTtl != int(recTtl) { + t.Errorf("Test %d: expected TTL=%d, got TTL=%d", i, tt.expectedTtl, recTtl) + } + } else if ret == dns.RcodeNameError { + soaTtl := rec.Msg.Ns[0].Header().Ttl + if tt.expectedTtl != int(soaTtl) { + t.Errorf("Test %d: expected TTL=%d, got TTL=%d", i, tt.expectedTtl, soaTtl) + } + } + } +} + +func TestNegativeStaleMaskingPositiveCache(t *testing.T) { + c := New() + c.staleUpTo = time.Minute * 10 + c.Next = nxDomainBackend(60) + + req := new(dns.Msg) + qname := "cached.org." + req.SetQuestion(qname, dns.TypeA) + ctx := context.TODO() + + // Add an entry to Negative Cache": cached.org. = NXDOMAIN + expectedResult := dns.RcodeNameError + if ret, _ := c.ServeDNS(ctx, &test.ResponseWriter{}, req); ret != expectedResult { + t.Errorf("Test 0 Negative Cache Population: expecting %v; got %v", expectedResult, ret) + } + + // Confirm item was added to negative cache and not to positive cache + if c.ncache.Len() == 0 { + t.Errorf("Test 0 Negative Cache Population: item not added to negative cache") + } + if c.pcache.Len() != 0 { + t.Errorf("Test 0 Negative Cache Population: item added to positive cache") + } + + // Set the Backend to return non-cachable errors only + c.Next = plugin.HandlerFunc(func(context.Context, dns.ResponseWriter, *dns.Msg) (int, error) { + return 255, nil // Below, a 255 means we tried querying upstream. 
+ }) + + // Confirm we get the NXDOMAIN from the negative cache, not the error form the backend + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + req = new(dns.Msg) + req.SetQuestion(qname, dns.TypeA) + expectedResult = dns.RcodeNameError + if c.ServeDNS(ctx, rec, req); rec.Rcode != expectedResult { + t.Errorf("Test 1 NXDOMAIN from Negative Cache: expecting %v; got %v", expectedResult, rec.Rcode) + } + + // Jump into the future beyond when the negative cache item would go stale + // but before the item goes rotten (exceeds serve stale time) + c.now = func() time.Time { return time.Now().Add(time.Duration(5) * time.Minute) } + + // Set Backend to return a positive NOERROR + A record response + c.Next = BackendHandler() + + // Make a query for the stale cache item + rec = dnstest.NewRecorder(&test.ResponseWriter{}) + req = new(dns.Msg) + req.SetQuestion(qname, dns.TypeA) + expectedResult = dns.RcodeNameError + if c.ServeDNS(ctx, rec, req); rec.Rcode != expectedResult { + t.Errorf("Test 2 NOERROR from Backend: expecting %v; got %v", expectedResult, rec.Rcode) + } + + // Confirm that prefetch removes the negative cache item. + waitFor := 3 + for i := 1; i <= waitFor; i++ { + if c.ncache.Len() != 0 { + if i == waitFor { + t.Errorf("Test 2 NOERROR from Backend: item still exists in negative cache") + } + time.Sleep(time.Second) + continue + } + } + + // Confirm that positive cache has the item + if c.pcache.Len() != 1 { + t.Errorf("Test 2 NOERROR from Backend: item missing from positive cache") + } + + // Backend - Give error only + c.Next = plugin.HandlerFunc(func(context.Context, dns.ResponseWriter, *dns.Msg) (int, error) { + return 255, nil // Below, a 255 means we tried querying upstream. + }) + + // Query again, expect that positive cache entry is not masked by a negative cache entry + rec = dnstest.NewRecorder(&test.ResponseWriter{}) + req = new(dns.Msg) + req.SetQuestion(qname, dns.TypeA) + expectedResult = dns.RcodeSuccess + if ret, _ := c.ServeDNS(ctx, rec, req); ret != expectedResult { + t.Errorf("Test 3 NOERROR from Cache: expecting %v; got %v", expectedResult, ret) + } +} + +func BenchmarkCacheResponse(b *testing.B) { + c := New() + c.prefetch = 1 + c.Next = BackendHandler() + + ctx := context.TODO() + + reqs := make([]*dns.Msg, 5) + for i, q := range []string{"example1", "example2", "a", "b", "ddd"} { + reqs[i] = new(dns.Msg) + reqs[i].SetQuestion(q+".example.org.", dns.TypeA) + } + + b.StartTimer() + + j := 0 + for i := 0; i < b.N; i++ { + req := reqs[j] + c.ServeDNS(ctx, &test.ResponseWriter{}, req) + j = (j + 1) % 5 + } +} + +func BackendHandler() plugin.Handler { + return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + m := new(dns.Msg) + m.SetReply(r) + m.Response = true + m.RecursionAvailable = true + + owner := m.Question[0].Name + m.Answer = []dns.RR{test.A(owner + " 303 IN A 127.0.0.53")} + + w.WriteMsg(m) + return dns.RcodeSuccess, nil + }) +} + +func nxDomainBackend(ttl int) plugin.Handler { + return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + m := new(dns.Msg) + m.SetReply(r) + m.Response, m.RecursionAvailable = true, true + + m.Ns = []dns.RR{test.SOA(fmt.Sprintf("example.org. %d IN SOA sns.dns.icann.org. noc.dns.icann.org. 
2016082540 7200 3600 1209600 3600", ttl))} + + m.MsgHdr.Rcode = dns.RcodeNameError + w.WriteMsg(m) + return dns.RcodeNameError, nil + }) +} + +func ttlBackend(ttl int) plugin.Handler { + return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + m := new(dns.Msg) + m.SetReply(r) + m.Response, m.RecursionAvailable = true, true + + m.Answer = []dns.RR{test.A(fmt.Sprintf("example.org. %d IN A 127.0.0.53", ttl))} + w.WriteMsg(m) + return dns.RcodeSuccess, nil + }) +} + +func servFailBackend(ttl int) plugin.Handler { + return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + m := new(dns.Msg) + m.SetReply(r) + m.Response, m.RecursionAvailable = true, true + + m.Ns = []dns.RR{test.SOA(fmt.Sprintf("example.org. %d IN SOA sns.dns.icann.org. noc.dns.icann.org. 2016082540 7200 3600 1209600 3600", ttl))} + + m.MsgHdr.Rcode = dns.RcodeServerFailure + w.WriteMsg(m) + return dns.RcodeServerFailure, nil + }) +} + +func TestComputeTTL(t *testing.T) { + tests := []struct { + msgTTL time.Duration + minTTL time.Duration + maxTTL time.Duration + expectedTTL time.Duration + }{ + {1800 * time.Second, 300 * time.Second, 3600 * time.Second, 1800 * time.Second}, + {299 * time.Second, 300 * time.Second, 3600 * time.Second, 300 * time.Second}, + {299 * time.Second, 0 * time.Second, 3600 * time.Second, 299 * time.Second}, + {3601 * time.Second, 300 * time.Second, 3600 * time.Second, 3600 * time.Second}, + } + for i, test := range tests { + ttl := computeTTL(test.msgTTL, test.minTTL, test.maxTTL) + if ttl != test.expectedTTL { + t.Errorf("Test %v: Expected ttl %v but found: %v", i, test.expectedTTL, ttl) + } + } +} + +func TestCacheWildcardMetadata(t *testing.T) { + c := New() + qname := "foo.bar.example.org." + wildcard := "*.bar.example.org." + c.Next = wildcardMetadataBackend(qname, wildcard) + + req := new(dns.Msg) + req.SetQuestion(qname, dns.TypeA) + state := request.Request{W: &test.ResponseWriter{}, Req: req} + + // 1. Test writing wildcard metadata retrieved from backend to the cache + + ctx := metadata.ContextWithMetadata(context.TODO()) + w := dnstest.NewRecorder(&test.ResponseWriter{}) + c.ServeDNS(ctx, w, req) + if c.pcache.Len() != 1 { + t.Errorf("Msg should have been cached") + } + _, k := key(qname, w.Msg, response.NoError, state.Do(), state.Req.CheckingDisabled) + i, _ := c.pcache.Get(k) + if i.(*item).wildcard != wildcard { + t.Errorf("expected wildcard response to enter cache with cache item's wildcard = %q, got %q", wildcard, i.(*item).wildcard) + } + + // 2. Test retrieving the cached item from cache and writing its wildcard value to metadata + + // reset context and response writer + ctx = metadata.ContextWithMetadata(context.TODO()) + w = dnstest.NewRecorder(&test.ResponseWriter{}) + + c.ServeDNS(ctx, w, req) + f := metadata.ValueFunc(ctx, "zone/wildcard") + if f == nil { + t.Fatal("expected metadata func for wildcard response retrieved from cache, got nil") + } + if f() != wildcard { + t.Errorf("after retrieving wildcard item from cache, expected \"zone/wildcard\" metadata value to be %q, got %q", wildcard, i.(*item).wildcard) + } +} + +func TestCacheKeepTTL(t *testing.T) { + defaultTtl := 60 + + c := New() + c.Next = ttlBackend(defaultTtl) + + req := new(dns.Msg) + req.SetQuestion("cached.org.", dns.TypeA) + ctx := context.TODO() + + // Cache cached.org. 
with 60s TTL + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + c.keepttl = true + c.ServeDNS(ctx, rec, req) + + tests := []struct { + name string + futureSeconds int + }{ + {"cached.org.", 0}, + {"cached.org.", 30}, + {"uncached.org.", 60}, + } + + for i, tt := range tests { + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + c.now = func() time.Time { return time.Now().Add(time.Duration(tt.futureSeconds) * time.Second) } + r := req.Copy() + r.SetQuestion(tt.name, dns.TypeA) + c.ServeDNS(ctx, rec, r) + + recTtl := rec.Msg.Answer[0].Header().Ttl + if defaultTtl != int(recTtl) { + t.Errorf("Test %d: expecting TTL=%d, got TTL=%d", i, defaultTtl, recTtl) + } + } +} + +// TestCacheSeparation verifies whether the cache maintains separation for specific DNS query types and options. +func TestCacheSeparation(t *testing.T) { + now, _ := time.Parse(time.UnixDate, "Fri Apr 21 10:51:21 BST 2017") + utc := now.UTC() + + testCases := []struct { + name string + initial test.Case + query test.Case + expectCached bool // if a cache entry should be found before inserting + }{ + { + name: "query type should be unique", + initial: test.Case{ + Qname: "example.org.", + Qtype: dns.TypeA, + }, + query: test.Case{ + Qname: "example.org.", + Qtype: dns.TypeAAAA, + }, + }, + { + name: "DO bit should be unique", + initial: test.Case{ + Qname: "example.org.", + Qtype: dns.TypeA, + }, + query: test.Case{ + Qname: "example.org.", + Qtype: dns.TypeA, + Do: true, + }, + }, + { + name: "CD bit should be unique", + initial: test.Case{ + Qname: "example.org.", + Qtype: dns.TypeA, + }, + query: test.Case{ + Qname: "example.org.", + Qtype: dns.TypeA, + CheckingDisabled: true, + }, + }, + { + name: "CD bit and DO bit should be unique", + initial: test.Case{ + Qname: "example.org.", + Qtype: dns.TypeA, + }, + query: test.Case{ + Qname: "example.org.", + Qtype: dns.TypeA, + CheckingDisabled: true, + Do: true, + }, + }, + { + name: "CD bit, DO bit, and query type should be unique", + initial: test.Case{ + Qname: "example.org.", + Qtype: dns.TypeA, + }, + query: test.Case{ + Qname: "example.org.", + Qtype: dns.TypeMX, + CheckingDisabled: true, + Do: true, + }, + }, + { + name: "authoritative answer bit should NOT be unique", + initial: test.Case{ + Qname: "example.org.", + Qtype: dns.TypeA, + }, + query: test.Case{ + Qname: "example.org.", + Qtype: dns.TypeA, + Authoritative: true, + }, + expectCached: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := New() + crr := &ResponseWriter{ResponseWriter: nil, Cache: c} + + // Insert initial cache entry + m := tc.initial.Msg() + m = cacheMsg(m, tc.initial) + state := request.Request{W: &test.ResponseWriter{}, Req: m} + + mt, _ := response.Typify(m, utc) + valid, k := key(state.Name(), m, mt, state.Do(), state.Req.CheckingDisabled) + + if valid { + // Insert cache entry + crr.set(m, k, mt, c.pttl) + } + + // Attempt to retrieve cache entry + m = tc.query.Msg() + m = cacheMsg(m, tc.query) + state = request.Request{W: &test.ResponseWriter{}, Req: m} + + item := c.getIgnoreTTL(time.Now().UTC(), state, "dns://:53") + found := item != nil + + if !tc.expectCached && found { + t.Fatal("Found cache message should that should not exist prior to inserting") + } + if tc.expectCached && !found { + t.Fatal("Did not find cache message that should exist prior to inserting") + } + }) + } +} + +// wildcardMetadataBackend mocks a backend that responds with a response for qname synthesized by wildcard +// and sets the zone/wildcard metadata value +func 
wildcardMetadataBackend(qname, wildcard string) plugin.Handler { + return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + m := new(dns.Msg) + m.SetReply(r) + m.Response, m.RecursionAvailable = true, true + m.Answer = []dns.RR{test.A(qname + " 300 IN A 127.0.0.1")} + metadata.SetValueFunc(ctx, "zone/wildcard", func() string { + return wildcard + }) + w.WriteMsg(m) + + return dns.RcodeSuccess, nil + }) +} diff --git a/plugin/cache/dnssec.go b/plugin/cache/dnssec.go new file mode 100644 index 0000000..ec5ff41 --- /dev/null +++ b/plugin/cache/dnssec.go @@ -0,0 +1,24 @@ +package cache + +import "github.com/miekg/dns" + +// filterRRSlice filters out OPT RRs, and sets all RR TTLs to ttl. +// If dup is true the RRs in rrs are _copied_ into the slice that is +// returned. +func filterRRSlice(rrs []dns.RR, ttl uint32, dup bool) []dns.RR { + j := 0 + rs := make([]dns.RR, len(rrs)) + for _, r := range rrs { + if r.Header().Rrtype == dns.TypeOPT { + continue + } + r.Header().Ttl = ttl + if dup { + rs[j] = dns.Copy(r) + } else { + rs[j] = r + } + j++ + } + return rs[:j] +} diff --git a/plugin/cache/dnssec_test.go b/plugin/cache/dnssec_test.go new file mode 100644 index 0000000..b73d52c --- /dev/null +++ b/plugin/cache/dnssec_test.go @@ -0,0 +1,126 @@ +package cache + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +func TestResponseWithDNSSEC(t *testing.T) { + // We do 2 queries, one where we want non-dnssec and one with dnssec and check the responses in each of them + var tcs = []test.Case{ + { + Qname: "invent.example.org.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.CNAME("invent.example.org. 1781 IN CNAME leptone.example.org."), + test.A("leptone.example.org. 1781 IN A 195.201.182.103"), + }, + }, + { + Qname: "invent.example.org.", Qtype: dns.TypeA, + Do: true, + AuthenticatedData: true, + Answer: []dns.RR{ + test.CNAME("invent.example.org. 1781 IN CNAME leptone.example.org."), + test.RRSIG("invent.example.org. 1781 IN RRSIG CNAME 8 3 1800 20201012085750 20200912082613 57411 example.org. ijSv5FmsNjFviBcOFwQgqjt073lttxTTNqkno6oMa3DD3kC+"), + test.A("leptone.example.org. 1781 IN A 195.201.182.103"), + test.RRSIG("leptone.example.org. 1781 IN RRSIG A 8 3 1800 20201012093630 20200912083827 57411 example.org. 
eLuSOkLAzm/WIOpaZD3/4TfvKP1HAFzjkis9LIJSRVpQt307dm9WY9"), + }, + }, + } + + c := New() + c.Next = dnssecHandler() + + for i, tc := range tcs { + m := tc.Msg() + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + c.ServeDNS(context.TODO(), rec, m) + if tc.AuthenticatedData != rec.Msg.AuthenticatedData { + t.Errorf("Test %d, expected AuthenticatedData=%v", i, tc.AuthenticatedData) + } + if err := test.Section(tc, test.Answer, rec.Msg.Answer); err != nil { + t.Errorf("Test %d, expected no error, got %s", i, err) + } + } + + // now do the reverse + c = New() + c.Next = dnssecHandler() + + for i, tc := range []test.Case{tcs[1], tcs[0]} { + m := tc.Msg() + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + c.ServeDNS(context.TODO(), rec, m) + if err := test.Section(tc, test.Answer, rec.Msg.Answer); err != nil { + t.Errorf("Test %d, expected no error, got %s", i, err) + } + } +} + +func dnssecHandler() plugin.Handler { + return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + m := new(dns.Msg) + m.SetQuestion("example.org.", dns.TypeA) + state := request.Request{W: &test.ResponseWriter{}, Req: r} + + m.AuthenticatedData = true + // If query has the DO bit, then send DNSSEC responses (RRSIGs) + if state.Do() { + m.Answer = make([]dns.RR, 4) + m.Answer[0] = test.CNAME("invent.example.org. 1781 IN CNAME leptone.example.org.") + m.Answer[1] = test.RRSIG("invent.example.org. 1781 IN RRSIG CNAME 8 3 1800 20201012085750 20200912082613 57411 example.org. ijSv5FmsNjFviBcOFwQgqjt073lttxTTNqkno6oMa3DD3kC+") + m.Answer[2] = test.A("leptone.example.org. 1781 IN A 195.201.182.103") + m.Answer[3] = test.RRSIG("leptone.example.org. 1781 IN RRSIG A 8 3 1800 20201012093630 20200912083827 57411 example.org. eLuSOkLAzm/WIOpaZD3/4TfvKP1HAFzjkis9LIJSRVpQt307dm9WY9") + } else { + m.Answer = make([]dns.RR, 2) + m.Answer[0] = test.CNAME("invent.example.org. 1781 IN CNAME leptone.example.org.") + m.Answer[1] = test.A("leptone.example.org. 1781 IN A 195.201.182.103") + } + w.WriteMsg(m) + return dns.RcodeSuccess, nil + }) +} + +func TestFilterRRSlice(t *testing.T) { + rrs := []dns.RR{ + test.CNAME("invent.example.org. 1781 IN CNAME leptone.example.org."), + test.RRSIG("invent.example.org. 1781 IN RRSIG CNAME 8 3 1800 20201012085750 20200912082613 57411 example.org. ijSv5FmsNjFviBcOFwQgqjt073lttxTTNqkno6oMa3DD3kC+"), + test.A("leptone.example.org. 1781 IN A 195.201.182.103"), + test.RRSIG("leptone.example.org. 1781 IN RRSIG A 8 3 1800 20201012093630 20200912083827 57411 example.org. 
eLuSOkLAzm/WIOpaZD3/4TfvKP1HAFzjkis9LIJSRVpQt307dm9WY9"), + } + + filter1 := filterRRSlice(rrs, 0, false) + if len(filter1) != 4 { + t.Errorf("Expected 4 RRs after filtering, got %d", len(filter1)) + } + rrsig := 0 + for _, f := range filter1 { + if f.Header().Rrtype == dns.TypeRRSIG { + rrsig++ + } + } + if rrsig != 2 { + t.Errorf("Expected 2 RRSIGs after filtering, got %d", rrsig) + } + + filter2 := filterRRSlice(rrs, 0, false) + if len(filter2) != 4 { + t.Errorf("Expected 4 RRs after filtering, got %d", len(filter2)) + } + rrsig = 0 + for _, f := range filter2 { + if f.Header().Rrtype == dns.TypeRRSIG { + rrsig++ + } + } + if rrsig != 2 { + t.Errorf("Expected 2 RRSIGs after filtering, got %d", rrsig) + } +} diff --git a/plugin/cache/error_test.go b/plugin/cache/error_test.go new file mode 100644 index 0000000..cd18fda --- /dev/null +++ b/plugin/cache/error_test.go @@ -0,0 +1,38 @@ +package cache + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestFormErr(t *testing.T) { + c := New() + c.Next = formErrHandler() + + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + + c.ServeDNS(context.TODO(), rec, req) + + if c.pcache.Len() != 0 { + t.Errorf("Cached %s, while reply had %d", "example.org.", rec.Msg.Rcode) + } +} + +// formErrHandler is a fake plugin implementation which returns a FORMERR for a reply. +func formErrHandler() plugin.Handler { + return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + m := new(dns.Msg) + m.SetQuestion("example.net.", dns.TypeA) + m.Rcode = dns.RcodeFormatError + w.WriteMsg(m) + return dns.RcodeSuccess, nil + }) +} diff --git a/plugin/cache/freq/freq.go b/plugin/cache/freq/freq.go new file mode 100644 index 0000000..f545f22 --- /dev/null +++ b/plugin/cache/freq/freq.go @@ -0,0 +1,55 @@ +// Package freq keeps track of last X seen events. The events themselves are not stored +// here. So the Freq type should be added next to the thing it is tracking. +package freq + +import ( + "sync" + "time" +) + +// Freq tracks the frequencies of things. +type Freq struct { + // Last time we saw a query for this element. + last time.Time + // Number of this in the last time slice. + hits int + + sync.RWMutex +} + +// New returns a new initialized Freq. +func New(t time.Time) *Freq { + return &Freq{last: t, hits: 0} +} + +// Update updates the number of hits. Last time seen will be set to now. +// If the last time we've seen this entity is within now - d, we increment hits, otherwise +// we reset hits to 1. It returns the number of hits. +func (f *Freq) Update(d time.Duration, now time.Time) int { + earliest := now.Add(-1 * d) + f.Lock() + defer f.Unlock() + if f.last.Before(earliest) { + f.last = now + f.hits = 1 + return f.hits + } + f.last = now + f.hits++ + return f.hits +} + +// Hits returns the number of hits that we have seen, according to the updates we have done to f. +func (f *Freq) Hits() int { + f.RLock() + defer f.RUnlock() + return f.hits +} + +// Reset resets f to time t and hits to hits. 
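+// The cache plugin uses this after a prefetch to carry the hit count of the old item over to
+// the freshly stored one (see doPrefetch in plugin/cache/handler.go).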
+func (f *Freq) Reset(t time.Time, hits int) { + f.Lock() + defer f.Unlock() + f.last = t + f.hits = hits +} diff --git a/plugin/cache/freq/freq_test.go b/plugin/cache/freq/freq_test.go new file mode 100644 index 0000000..740194c --- /dev/null +++ b/plugin/cache/freq/freq_test.go @@ -0,0 +1,36 @@ +package freq + +import ( + "testing" + "time" +) + +func TestFreqUpdate(t *testing.T) { + now := time.Now().UTC() + f := New(now) + window := 1 * time.Minute + + f.Update(window, time.Now().UTC()) + f.Update(window, time.Now().UTC()) + f.Update(window, time.Now().UTC()) + hitsCheck(t, f, 3) + + f.Reset(now, 0) + history := time.Now().UTC().Add(-3 * time.Minute) + f.Update(window, history) + hitsCheck(t, f, 1) +} + +func TestReset(t *testing.T) { + f := New(time.Now().UTC()) + f.Update(1*time.Minute, time.Now().UTC()) + hitsCheck(t, f, 1) + f.Reset(time.Now().UTC(), 0) + hitsCheck(t, f, 0) +} + +func hitsCheck(t *testing.T, f *Freq, expected int) { + if x := f.Hits(); x != expected { + t.Fatalf("Expected hits to be %d, got %d", expected, x) + } +} diff --git a/plugin/cache/fuzz.go b/plugin/cache/fuzz.go new file mode 100644 index 0000000..43f4d26 --- /dev/null +++ b/plugin/cache/fuzz.go @@ -0,0 +1,12 @@ +//go:build gofuzz + +package cache + +import ( + "github.com/coredns/coredns/plugin/pkg/fuzz" +) + +// Fuzz fuzzes cache. +func Fuzz(data []byte) int { + return fuzz.Do(New(), data) +} diff --git a/plugin/cache/handler.go b/plugin/cache/handler.go new file mode 100644 index 0000000..38a8bfe --- /dev/null +++ b/plugin/cache/handler.go @@ -0,0 +1,157 @@ +package cache + +import ( + "context" + "math" + "time" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/metadata" + "github.com/coredns/coredns/plugin/metrics" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// ServeDNS implements the plugin.Handler interface. +func (c *Cache) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + rc := r.Copy() // We potentially modify r, to prevent other plugins from seeing this (r is a pointer), copy r into rc. + state := request.Request{W: w, Req: rc} + do := state.Do() + cd := r.CheckingDisabled + ad := r.AuthenticatedData + + zone := plugin.Zones(c.Zones).Matches(state.Name()) + if zone == "" { + return plugin.NextOrFailure(c.Name(), c.Next, ctx, w, rc) + } + + now := c.now().UTC() + server := metrics.WithServer(ctx) + + // On cache refresh, we will just use the DO bit from the incoming query for the refresh since we key our cache + // with the query DO bit. That means two separate cache items for the query DO bit true or false. In the situation + // in which upstream doesn't support DNSSEC, the two cache items will effectively be the same. Regardless, any + // DNSSEC RRs in the response are written to cache with the response. + + ttl := 0 + i := c.getIgnoreTTL(now, state, server) + if i == nil { + crr := &ResponseWriter{ResponseWriter: w, Cache: c, state: state, server: server, do: do, ad: ad, cd: cd, + nexcept: c.nexcept, pexcept: c.pexcept, wildcardFunc: wildcardFunc(ctx)} + return c.doRefresh(ctx, state, crr) + } + ttl = i.ttl(now) + if ttl < 0 { + // serve stale behavior + if c.verifyStale { + crr := &ResponseWriter{ResponseWriter: w, Cache: c, state: state, server: server, do: do, cd: cd} + cw := newVerifyStaleResponseWriter(crr) + ret, err := c.doRefresh(ctx, state, cw) + if cw.refreshed { + return ret, err + } + } + + // Adjust the time to get a 0 TTL in the reply built from a stale item. 
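+ // Because ttl is negative here, adding it winds "now" back to the item's exact expiry moment,
+ // so i.ttl(now) evaluates to 0 when the stale reply is built from this item below.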
+ now = now.Add(time.Duration(ttl) * time.Second) + if !c.verifyStale { + cw := newPrefetchResponseWriter(server, state, c) + go c.doPrefetch(ctx, state, cw, i, now) + } + servedStale.WithLabelValues(server, c.zonesMetricLabel, c.viewMetricLabel).Inc() + } else if c.shouldPrefetch(i, now) { + cw := newPrefetchResponseWriter(server, state, c) + go c.doPrefetch(ctx, state, cw, i, now) + } + + if i.wildcard != "" { + // Set wildcard source record name to metadata + metadata.SetValueFunc(ctx, "zone/wildcard", func() string { + return i.wildcard + }) + } + + if c.keepttl { + // If keepttl is enabled we fake the current time to the stored + // one so that we always get the original TTL + now = i.stored + } + resp := i.toMsg(r, now, do, ad) + w.WriteMsg(resp) + return dns.RcodeSuccess, nil +} + +func wildcardFunc(ctx context.Context) func() string { + return func() string { + // Get wildcard source record name from metadata + if f := metadata.ValueFunc(ctx, "zone/wildcard"); f != nil { + return f() + } + return "" + } +} + +func (c *Cache) doPrefetch(ctx context.Context, state request.Request, cw *ResponseWriter, i *item, now time.Time) { + cachePrefetches.WithLabelValues(cw.server, c.zonesMetricLabel, c.viewMetricLabel).Inc() + c.doRefresh(ctx, state, cw) + + // When prefetching we loose the item i, and with it the frequency + // that we've gathered sofar. See we copy the frequencies info back + // into the new item that was stored in the cache. + if i1 := c.exists(state); i1 != nil { + i1.Freq.Reset(now, i.Freq.Hits()) + } +} + +func (c *Cache) doRefresh(ctx context.Context, state request.Request, cw dns.ResponseWriter) (int, error) { + return plugin.NextOrFailure(c.Name(), c.Next, ctx, cw, state.Req) +} + +func (c *Cache) shouldPrefetch(i *item, now time.Time) bool { + if c.prefetch <= 0 { + return false + } + i.Freq.Update(c.duration, now) + threshold := int(math.Ceil(float64(c.percentage) / 100 * float64(i.origTTL))) + return i.Freq.Hits() >= c.prefetch && i.ttl(now) <= threshold +} + +// Name implements the Handler interface. +func (c *Cache) Name() string { return "cache" } + +// getIgnoreTTL unconditionally returns an item if it exists in the cache. 
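+// An expired item is only returned while it is still within the serve_stale window (staleUpTo);
+// the function also records the request, hit and miss metrics for this lookup.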
+func (c *Cache) getIgnoreTTL(now time.Time, state request.Request, server string) *item { + k := hash(state.Name(), state.QType(), state.Do(), state.Req.CheckingDisabled) + cacheRequests.WithLabelValues(server, c.zonesMetricLabel, c.viewMetricLabel).Inc() + + if i, ok := c.ncache.Get(k); ok { + itm := i.(*item) + ttl := itm.ttl(now) + if itm.matches(state) && (ttl > 0 || (c.staleUpTo > 0 && -ttl < int(c.staleUpTo.Seconds()))) { + cacheHits.WithLabelValues(server, Denial, c.zonesMetricLabel, c.viewMetricLabel).Inc() + return i.(*item) + } + } + if i, ok := c.pcache.Get(k); ok { + itm := i.(*item) + ttl := itm.ttl(now) + if itm.matches(state) && (ttl > 0 || (c.staleUpTo > 0 && -ttl < int(c.staleUpTo.Seconds()))) { + cacheHits.WithLabelValues(server, Success, c.zonesMetricLabel, c.viewMetricLabel).Inc() + return i.(*item) + } + } + cacheMisses.WithLabelValues(server, c.zonesMetricLabel, c.viewMetricLabel).Inc() + return nil +} + +func (c *Cache) exists(state request.Request) *item { + k := hash(state.Name(), state.QType(), state.Do(), state.Req.CheckingDisabled) + if i, ok := c.ncache.Get(k); ok { + return i.(*item) + } + if i, ok := c.pcache.Get(k); ok { + return i.(*item) + } + return nil +} diff --git a/plugin/cache/item.go b/plugin/cache/item.go new file mode 100644 index 0000000..c5aeccd --- /dev/null +++ b/plugin/cache/item.go @@ -0,0 +1,107 @@ +package cache + +import ( + "strings" + "time" + + "github.com/coredns/coredns/plugin/cache/freq" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +type item struct { + Name string + QType uint16 + Rcode int + AuthenticatedData bool + RecursionAvailable bool + Answer []dns.RR + Ns []dns.RR + Extra []dns.RR + wildcard string + + origTTL uint32 + stored time.Time + + *freq.Freq +} + +func newItem(m *dns.Msg, now time.Time, d time.Duration) *item { + i := new(item) + if len(m.Question) != 0 { + i.Name = m.Question[0].Name + i.QType = m.Question[0].Qtype + } + i.Rcode = m.Rcode + i.AuthenticatedData = m.AuthenticatedData + i.RecursionAvailable = m.RecursionAvailable + i.Answer = m.Answer + i.Ns = m.Ns + i.Extra = make([]dns.RR, len(m.Extra)) + // Don't copy OPT records as these are hop-by-hop. + j := 0 + for _, e := range m.Extra { + if e.Header().Rrtype == dns.TypeOPT { + continue + } + i.Extra[j] = e + j++ + } + i.Extra = i.Extra[:j] + + i.origTTL = uint32(d.Seconds()) + i.stored = now.UTC() + + i.Freq = new(freq.Freq) + + return i +} + +// toMsg turns i into a message, it tailors the reply to m. +// The Authoritative bit should be set to 0, but some client stub resolver implementations, most notably, +// on some legacy systems(e.g. ubuntu 14.04 with glib version 2.20), low-level glibc function `getaddrinfo` +// useb by Python/Ruby/etc.. will discard answers that do not have this bit set. +// So we're forced to always set this to 1; regardless if the answer came from the cache or not. +// On newer systems(e.g. ubuntu 16.04 with glib version 2.23), this issue is resolved. +// So we may set this bit back to 0 in the future ? +func (i *item) toMsg(m *dns.Msg, now time.Time, do bool, ad bool) *dns.Msg { + m1 := new(dns.Msg) + m1.SetReply(m) + + // Set this to true as some DNS clients discard the *entire* packet when it's non-authoritative. + // This is probably not according to spec, but the bit itself is not super useful as this point, so + // just set it to true. + m1.Authoritative = true + m1.AuthenticatedData = i.AuthenticatedData + if !do && !ad { + // When DNSSEC was not wanted, it can't be authenticated data. 
+ // However, retain the AD bit if the requester set the AD bit, per RFC6840 5.7-5.8 + m1.AuthenticatedData = false + } + m1.RecursionAvailable = i.RecursionAvailable + m1.Rcode = i.Rcode + + m1.Answer = make([]dns.RR, len(i.Answer)) + m1.Ns = make([]dns.RR, len(i.Ns)) + m1.Extra = make([]dns.RR, len(i.Extra)) + + ttl := uint32(i.ttl(now)) + m1.Answer = filterRRSlice(i.Answer, ttl, true) + m1.Ns = filterRRSlice(i.Ns, ttl, true) + m1.Extra = filterRRSlice(i.Extra, ttl, true) + + return m1 +} + +func (i *item) ttl(now time.Time) int { + ttl := int(i.origTTL) - int(now.UTC().Sub(i.stored).Seconds()) + return ttl +} + +func (i *item) matches(state request.Request) bool { + if state.QType() == i.QType && strings.EqualFold(state.QName(), i.Name) { + return true + } + return false +} diff --git a/plugin/cache/log_test.go b/plugin/cache/log_test.go new file mode 100644 index 0000000..220b206 --- /dev/null +++ b/plugin/cache/log_test.go @@ -0,0 +1,5 @@ +package cache + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/cache/metrics.go b/plugin/cache/metrics.go new file mode 100644 index 0000000..77edb02 --- /dev/null +++ b/plugin/cache/metrics.go @@ -0,0 +1,67 @@ +package cache + +import ( + "github.com/coredns/coredns/plugin" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +var ( + // cacheSize is total elements in the cache by cache type. + cacheSize = promauto.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: plugin.Namespace, + Subsystem: "cache", + Name: "entries", + Help: "The number of elements in the cache.", + }, []string{"server", "type", "zones", "view"}) + // cacheRequests is a counter of all requests through the cache. + cacheRequests = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "cache", + Name: "requests_total", + Help: "The count of cache requests.", + }, []string{"server", "zones", "view"}) + // cacheHits is counter of cache hits by cache type. + cacheHits = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "cache", + Name: "hits_total", + Help: "The count of cache hits.", + }, []string{"server", "type", "zones", "view"}) + // cacheMisses is the counter of cache misses. - Deprecated + cacheMisses = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "cache", + Name: "misses_total", + Help: "The count of cache misses. Deprecated, derive misses from cache hits/requests counters.", + }, []string{"server", "zones", "view"}) + // cachePrefetches is the number of time the cache has prefetched a cached item. + cachePrefetches = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "cache", + Name: "prefetch_total", + Help: "The number of times the cache has prefetched a cached item.", + }, []string{"server", "zones", "view"}) + // cacheDrops is the number responses that are not cached, because the reply is malformed. + cacheDrops = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "cache", + Name: "drops_total", + Help: "The number responses that are not cached, because the reply is malformed.", + }, []string{"server", "zones", "view"}) + // servedStale is the number of requests served from stale cache entries. 
+ servedStale = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "cache", + Name: "served_stale_total", + Help: "The number of requests served from stale cache entries.", + }, []string{"server", "zones", "view"}) + // evictions is the counter of cache evictions. + evictions = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "cache", + Name: "evictions_total", + Help: "The count of cache evictions.", + }, []string{"server", "type", "zones", "view"}) +) diff --git a/plugin/cache/prefetch_test.go b/plugin/cache/prefetch_test.go new file mode 100644 index 0000000..3085fe0 --- /dev/null +++ b/plugin/cache/prefetch_test.go @@ -0,0 +1,228 @@ +package cache + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestPrefetch(t *testing.T) { + tests := []struct { + qname string + ttl int + prefetch int + verifications []verification + }{ + { + qname: "hits.reset.example.org.", + ttl: 80, + prefetch: 1, + verifications: []verification{ + { + after: 0 * time.Second, + answer: "hits.reset.example.org. 80 IN A 127.0.0.1", + fetch: true, // Initial fetch + }, + { + after: 73 * time.Second, + answer: "hits.reset.example.org. 7 IN A 127.0.0.1", + fetch: true, // Triggers prefetch with 7 TTL (10% of 80 = 8 TTL threshold) + }, + { + after: 80 * time.Second, + answer: "hits.reset.example.org. 73 IN A 127.0.0.2", + }, + }, + }, + { + qname: "short.ttl.example.org.", + ttl: 5, + prefetch: 1, + verifications: []verification{ + { + after: 0 * time.Second, + answer: "short.ttl.example.org. 5 IN A 127.0.0.1", + fetch: true, + }, + { + after: 1 * time.Second, + answer: "short.ttl.example.org. 4 IN A 127.0.0.1", + }, + { + after: 4 * time.Second, + answer: "short.ttl.example.org. 1 IN A 127.0.0.1", + fetch: true, + }, + { + after: 5 * time.Second, + answer: "short.ttl.example.org. 4 IN A 127.0.0.2", + }, + }, + }, + { + qname: "no.prefetch.example.org.", + ttl: 30, + prefetch: 0, + verifications: []verification{ + { + after: 0 * time.Second, + answer: "no.prefetch.example.org. 30 IN A 127.0.0.1", + fetch: true, + }, + { + after: 15 * time.Second, + answer: "no.prefetch.example.org. 15 IN A 127.0.0.1", + }, + { + after: 29 * time.Second, + answer: "no.prefetch.example.org. 1 IN A 127.0.0.1", + }, + { + after: 30 * time.Second, + answer: "no.prefetch.example.org. 30 IN A 127.0.0.2", + fetch: true, + }, + }, + }, + { + // tests whether cache prefetches with the do bit + qname: "do.prefetch.example.org.", + ttl: 80, + prefetch: 1, + verifications: []verification{ + { + after: 0 * time.Second, + answer: "do.prefetch.example.org. 80 IN A 127.0.0.1", + do: true, + fetch: true, + }, + { + after: 73 * time.Second, + answer: "do.prefetch.example.org. 7 IN A 127.0.0.1", + do: true, + fetch: true, + }, + { + after: 80 * time.Second, + answer: "do.prefetch.example.org. 73 IN A 127.0.0.2", + do: true, + }, + { + // Should be 127.0.0.3 as 127.0.0.2 was the prefetch WITH do bit + after: 80 * time.Second, + answer: "do.prefetch.example.org. 80 IN A 127.0.0.3", + fetch: true, + }, + }, + }, + { + // tests whether cache prefetches with the cd bit + qname: "cd.prefetch.example.org.", + ttl: 80, + prefetch: 1, + verifications: []verification{ + { + after: 0 * time.Second, + answer: "cd.prefetch.example.org. 
80 IN A 127.0.0.1", + cd: true, + fetch: true, + }, + { + after: 73 * time.Second, + answer: "cd.prefetch.example.org. 7 IN A 127.0.0.1", + cd: true, + fetch: true, + }, + { + after: 80 * time.Second, + answer: "cd.prefetch.example.org. 73 IN A 127.0.0.2", + cd: true, + }, + { + // Should be 127.0.0.3 as 127.0.0.2 was the prefetch WITH cd bit + after: 80 * time.Second, + answer: "cd.prefetch.example.org. 80 IN A 127.0.0.3", + fetch: true, + }, + }, + }, + } + + t0, err := time.Parse(time.RFC3339, "2018-01-01T14:00:00+00:00") + if err != nil { + t.Fatal(err) + } + for _, tt := range tests { + t.Run(tt.qname, func(t *testing.T) { + fetchc := make(chan struct{}, 1) + + c := New() + c.Next = prefetchHandler(tt.qname, tt.ttl, fetchc) + c.prefetch = tt.prefetch + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + + for _, v := range tt.verifications { + c.now = func() time.Time { return t0.Add(v.after) } + + req := new(dns.Msg) + req.SetQuestion(tt.qname, dns.TypeA) + req.CheckingDisabled = v.cd + req.SetEdns0(512, v.do) + + c.ServeDNS(context.TODO(), rec, req) + if v.fetch { + select { + case <-fetchc: + // Prefetch handler was called. + case <-time.After(time.Second): + t.Fatalf("After %s: want request to trigger a prefetch", v.after) + } + } + if want, got := dns.RcodeSuccess, rec.Rcode; want != got { + t.Errorf("After %s: want rcode %d, got %d", v.after, want, got) + } + if want, got := 1, len(rec.Msg.Answer); want != got { + t.Errorf("After %s: want %d answer RR, got %d", v.after, want, got) + } + if want, got := test.A(v.answer).String(), rec.Msg.Answer[0].String(); want != got { + t.Errorf("After %s: want answer %s, got %s", v.after, want, got) + } + } + }) + } +} + +type verification struct { + after time.Duration + answer string + do bool + cd bool + // fetch defines whether a request is sent to the next handler. + fetch bool +} + +// prefetchHandler is a fake plugin implementation which returns a single A +// record with the given qname and ttl. The returned IP address starts at +// 127.0.0.1 and is incremented on every request. 
+func prefetchHandler(qname string, ttl int, fetchc chan struct{}) plugin.Handler { + i := 0 + return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + i++ + m := new(dns.Msg) + m.SetQuestion(qname, dns.TypeA) + m.Response = true + m.Answer = append(m.Answer, test.A(fmt.Sprintf("%s %d IN A 127.0.0.%d", qname, ttl, i))) + + w.WriteMsg(m) + fetchc <- struct{}{} + return dns.RcodeSuccess, nil + }) +} diff --git a/plugin/cache/setup.go b/plugin/cache/setup.go new file mode 100644 index 0000000..f8278b8 --- /dev/null +++ b/plugin/cache/setup.go @@ -0,0 +1,261 @@ +package cache + +import ( + "errors" + "fmt" + "strconv" + "strings" + "time" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/cache" + clog "github.com/coredns/coredns/plugin/pkg/log" +) + +var log = clog.NewWithPlugin("cache") + +func init() { plugin.Register("cache", setup) } + +func setup(c *caddy.Controller) error { + ca, err := cacheParse(c) + if err != nil { + return plugin.Error("cache", err) + } + + c.OnStartup(func() error { + ca.viewMetricLabel = dnsserver.GetConfig(c).ViewName + return nil + }) + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + ca.Next = next + return ca + }) + + return nil +} + +func cacheParse(c *caddy.Controller) (*Cache, error) { + ca := New() + + j := 0 + for c.Next() { + if j > 0 { + return nil, plugin.ErrOnce + } + j++ + + // cache [ttl] [zones..] + args := c.RemainingArgs() + if len(args) > 0 { + // first args may be just a number, then it is the ttl, if not it is a zone + ttl, err := strconv.Atoi(args[0]) + if err == nil { + // Reserve 0 (and smaller for future things) + if ttl <= 0 { + return nil, fmt.Errorf("cache TTL can not be zero or negative: %d", ttl) + } + ca.pttl = time.Duration(ttl) * time.Second + ca.nttl = time.Duration(ttl) * time.Second + args = args[1:] + } + } + origins := plugin.OriginsFromArgsOrServerBlock(args, c.ServerBlockKeys) + + // Refinements? In an extra block. 
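+ // Recognized block directives: success, denial, prefetch, serve_stale, servfail, disable and keepttl.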
+ for c.NextBlock() { + switch c.Val() { + // first number is cap, second is an new ttl + case Success: + args := c.RemainingArgs() + if len(args) == 0 { + return nil, c.ArgErr() + } + pcap, err := strconv.Atoi(args[0]) + if err != nil { + return nil, err + } + ca.pcap = pcap + if len(args) > 1 { + pttl, err := strconv.Atoi(args[1]) + if err != nil { + return nil, err + } + // Reserve 0 (and smaller for future things) + if pttl <= 0 { + return nil, fmt.Errorf("cache TTL can not be zero or negative: %d", pttl) + } + ca.pttl = time.Duration(pttl) * time.Second + if len(args) > 2 { + minpttl, err := strconv.Atoi(args[2]) + if err != nil { + return nil, err + } + // Reserve < 0 + if minpttl < 0 { + return nil, fmt.Errorf("cache min TTL can not be negative: %d", minpttl) + } + ca.minpttl = time.Duration(minpttl) * time.Second + } + } + case Denial: + args := c.RemainingArgs() + if len(args) == 0 { + return nil, c.ArgErr() + } + ncap, err := strconv.Atoi(args[0]) + if err != nil { + return nil, err + } + ca.ncap = ncap + if len(args) > 1 { + nttl, err := strconv.Atoi(args[1]) + if err != nil { + return nil, err + } + // Reserve 0 (and smaller for future things) + if nttl <= 0 { + return nil, fmt.Errorf("cache TTL can not be zero or negative: %d", nttl) + } + ca.nttl = time.Duration(nttl) * time.Second + if len(args) > 2 { + minnttl, err := strconv.Atoi(args[2]) + if err != nil { + return nil, err + } + // Reserve < 0 + if minnttl < 0 { + return nil, fmt.Errorf("cache min TTL can not be negative: %d", minnttl) + } + ca.minnttl = time.Duration(minnttl) * time.Second + } + } + case "prefetch": + args := c.RemainingArgs() + if len(args) == 0 || len(args) > 3 { + return nil, c.ArgErr() + } + amount, err := strconv.Atoi(args[0]) + if err != nil { + return nil, err + } + if amount < 0 { + return nil, fmt.Errorf("prefetch amount should be positive: %d", amount) + } + ca.prefetch = amount + + if len(args) > 1 { + dur, err := time.ParseDuration(args[1]) + if err != nil { + return nil, err + } + ca.duration = dur + } + if len(args) > 2 { + pct := args[2] + if x := pct[len(pct)-1]; x != '%' { + return nil, fmt.Errorf("last character of percentage should be `%%`, but is: %q", x) + } + pct = pct[:len(pct)-1] + + num, err := strconv.Atoi(pct) + if err != nil { + return nil, err + } + if num < 10 || num > 90 { + return nil, fmt.Errorf("percentage should fall in range [10, 90]: %d", num) + } + ca.percentage = num + } + + case "serve_stale": + args := c.RemainingArgs() + if len(args) > 2 { + return nil, c.ArgErr() + } + ca.staleUpTo = 1 * time.Hour + if len(args) > 0 { + d, err := time.ParseDuration(args[0]) + if err != nil { + return nil, err + } + if d < 0 { + return nil, errors.New("invalid negative duration for serve_stale") + } + ca.staleUpTo = d + } + ca.verifyStale = false + if len(args) > 1 { + mode := strings.ToLower(args[1]) + if mode != "immediate" && mode != "verify" { + return nil, fmt.Errorf("invalid value for serve_stale refresh mode: %s", mode) + } + ca.verifyStale = mode == "verify" + } + case "servfail": + args := c.RemainingArgs() + if len(args) != 1 { + return nil, c.ArgErr() + } + d, err := time.ParseDuration(args[0]) + if err != nil { + return nil, err + } + if d < 0 { + return nil, errors.New("invalid negative ttl for servfail") + } + if d > 5*time.Minute { + // RFC 2308 prohibits caching SERVFAIL longer than 5 minutes + return nil, errors.New("caching SERVFAIL responses over 5 minutes is not permitted") + } + ca.failttl = d + case "disable": + // disable [success|denial] [zones]... 
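+ // For example, "disable denial example.org" turns negative caching off for example.org.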
+ args := c.RemainingArgs() + if len(args) < 1 { + return nil, c.ArgErr() + } + + var zones []string + if len(args) > 1 { + for _, z := range args[1:] { // args[1:] define the list of zones to disable + nz := plugin.Name(z).Normalize() + if nz == "" { + return nil, fmt.Errorf("invalid disabled zone: %s", z) + } + zones = append(zones, nz) + } + } else { + // if no zones specified, default to root + zones = []string{"."} + } + + switch args[0] { // args[0] defines which cache to disable + case Denial: + ca.nexcept = zones + case Success: + ca.pexcept = zones + default: + return nil, fmt.Errorf("cache type for disable must be %q or %q", Success, Denial) + } + case "keepttl": + args := c.RemainingArgs() + if len(args) != 0 { + return nil, c.ArgErr() + } + ca.keepttl = true + default: + return nil, c.ArgErr() + } + } + + ca.Zones = origins + ca.zonesMetricLabel = strings.Join(origins, ",") + ca.pcache = cache.New(ca.pcap) + ca.ncache = cache.New(ca.ncap) + } + + return ca, nil +} diff --git a/plugin/cache/setup_test.go b/plugin/cache/setup_test.go new file mode 100644 index 0000000..46ac5bd --- /dev/null +++ b/plugin/cache/setup_test.go @@ -0,0 +1,262 @@ +package cache + +import ( + "fmt" + "testing" + "time" + + "github.com/coredns/caddy" +) + +func TestSetup(t *testing.T) { + tests := []struct { + input string + shouldErr bool + expectedNcap int + expectedPcap int + expectedNttl time.Duration + expectedMinNttl time.Duration + expectedPttl time.Duration + expectedMinPttl time.Duration + expectedPrefetch int + }{ + {`cache`, false, defaultCap, defaultCap, maxNTTL, minNTTL, maxTTL, minTTL, 0}, + {`cache {}`, false, defaultCap, defaultCap, maxNTTL, minNTTL, maxTTL, minTTL, 0}, + {`cache example.nl { + success 10 + }`, false, defaultCap, 10, maxNTTL, minNTTL, maxTTL, minTTL, 0}, + {`cache example.nl { + success 10 1800 30 + }`, false, defaultCap, 10, maxNTTL, minNTTL, 1800 * time.Second, 30 * time.Second, 0}, + {`cache example.nl { + success 10 + denial 10 15 + }`, false, 10, 10, 15 * time.Second, minNTTL, maxTTL, minTTL, 0}, + {`cache example.nl { + success 10 + denial 10 15 2 + }`, false, 10, 10, 15 * time.Second, 2 * time.Second, maxTTL, minTTL, 0}, + {`cache 25 example.nl { + success 10 + denial 10 15 + }`, false, 10, 10, 15 * time.Second, minNTTL, 25 * time.Second, minTTL, 0}, + {`cache 25 example.nl { + success 10 + denial 10 15 5 + }`, false, 10, 10, 15 * time.Second, 5 * time.Second, 25 * time.Second, minTTL, 0}, + {`cache aaa example.nl`, false, defaultCap, defaultCap, maxNTTL, minNTTL, maxTTL, minTTL, 0}, + {`cache { + prefetch 10 + }`, false, defaultCap, defaultCap, maxNTTL, minNTTL, maxTTL, minTTL, 10}, + + // fails + {`cache example.nl { + success + denial 10 15 + }`, true, defaultCap, defaultCap, maxTTL, minNTTL, maxTTL, minTTL, 0}, + {`cache example.nl { + success 15 + denial aaa + }`, true, defaultCap, defaultCap, maxTTL, minNTTL, maxTTL, minTTL, 0}, + {`cache example.nl { + positive 15 + negative aaa + }`, true, defaultCap, defaultCap, maxTTL, minNTTL, maxTTL, minTTL, 0}, + {`cache 0 example.nl`, true, defaultCap, defaultCap, maxTTL, minNTTL, maxTTL, minTTL, 0}, + {`cache -1 example.nl`, true, defaultCap, defaultCap, maxTTL, minNTTL, maxTTL, minTTL, 0}, + {`cache 1 example.nl { + positive 0 + }`, true, defaultCap, defaultCap, maxTTL, minNTTL, maxTTL, minTTL, 0}, + {`cache 1 example.nl { + positive 0 + prefetch -1 + }`, true, defaultCap, defaultCap, maxTTL, minNTTL, maxTTL, minTTL, 0}, + {`cache 1 example.nl { + prefetch 0 blurp + }`, true, defaultCap, defaultCap, maxTTL, minNTTL, 
maxTTL, minTTL, 0}, + {`cache + cache`, true, defaultCap, defaultCap, maxTTL, minNTTL, maxTTL, minTTL, 0}, + } + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + ca, err := cacheParse(c) + if test.shouldErr && err == nil { + t.Errorf("Test %v: Expected error but found nil", i) + continue + } else if !test.shouldErr && err != nil { + t.Errorf("Test %v: Expected no error but found error: %v", i, err) + continue + } + if test.shouldErr && err != nil { + continue + } + + if ca.ncap != test.expectedNcap { + t.Errorf("Test %v: Expected ncap %v but found: %v", i, test.expectedNcap, ca.ncap) + } + if ca.pcap != test.expectedPcap { + t.Errorf("Test %v: Expected pcap %v but found: %v", i, test.expectedPcap, ca.pcap) + } + if ca.nttl != test.expectedNttl { + t.Errorf("Test %v: Expected nttl %v but found: %v", i, test.expectedNttl, ca.nttl) + } + if ca.minnttl != test.expectedMinNttl { + t.Errorf("Test %v: Expected minnttl %v but found: %v", i, test.expectedMinNttl, ca.minnttl) + } + if ca.pttl != test.expectedPttl { + t.Errorf("Test %v: Expected pttl %v but found: %v", i, test.expectedPttl, ca.pttl) + } + if ca.minpttl != test.expectedMinPttl { + t.Errorf("Test %v: Expected minpttl %v but found: %v", i, test.expectedMinPttl, ca.minpttl) + } + if ca.prefetch != test.expectedPrefetch { + t.Errorf("Test %v: Expected prefetch %v but found: %v", i, test.expectedPrefetch, ca.prefetch) + } + } +} + +func TestServeStale(t *testing.T) { + tests := []struct { + input string + shouldErr bool + staleUpTo time.Duration + verifyStale bool + }{ + {"serve_stale", false, 1 * time.Hour, false}, + {"serve_stale 20m", false, 20 * time.Minute, false}, + {"serve_stale 1h20m", false, 80 * time.Minute, false}, + {"serve_stale 0m", false, 0, false}, + {"serve_stale 0", false, 0, false}, + {"serve_stale 0 verify", false, 0, true}, + {"serve_stale 0 immediate", false, 0, false}, + {"serve_stale 0 VERIFY", false, 0, true}, + // fails + {"serve_stale 20", true, 0, false}, + {"serve_stale -20m", true, 0, false}, + {"serve_stale aa", true, 0, false}, + {"serve_stale 1m nono", true, 0, false}, + {"serve_stale 0 after nono", true, 0, false}, + } + for i, test := range tests { + c := caddy.NewTestController("dns", fmt.Sprintf("cache {\n%s\n}", test.input)) + ca, err := cacheParse(c) + if test.shouldErr && err == nil { + t.Errorf("Test %v: Expected error but found nil", i) + continue + } else if !test.shouldErr && err != nil { + t.Errorf("Test %v: Expected no error but found error: %v", i, err) + continue + } + if test.shouldErr && err != nil { + continue + } + if ca.staleUpTo != test.staleUpTo { + t.Errorf("Test %v: Expected stale %v but found: %v", i, test.staleUpTo, ca.staleUpTo) + } + } +} + +func TestServfail(t *testing.T) { + tests := []struct { + input string + shouldErr bool + failttl time.Duration + }{ + {"servfail 1s", false, 1 * time.Second}, + {"servfail 5m", false, 5 * time.Minute}, + {"servfail 0s", false, 0}, + {"servfail 0", false, 0}, + // fails + {"servfail", true, minNTTL}, + {"servfail 6m", true, minNTTL}, + {"servfail 20", true, minNTTL}, + {"servfail -1s", true, minNTTL}, + {"servfail aa", true, minNTTL}, + {"servfail 1m invalid", true, minNTTL}, + } + for i, test := range tests { + c := caddy.NewTestController("dns", fmt.Sprintf("cache {\n%s\n}", test.input)) + ca, err := cacheParse(c) + if test.shouldErr && err == nil { + t.Errorf("Test %v: Expected error but found nil", i) + continue + } else if !test.shouldErr && err != nil { + t.Errorf("Test %v: Expected no error but found 
error: %v", i, err) + continue + } + if test.shouldErr && err != nil { + continue + } + if ca.failttl != test.failttl { + t.Errorf("Test %v: Expected stale %v but found: %v", i, test.failttl, ca.staleUpTo) + } + } +} + +func TestDisable(t *testing.T) { + tests := []struct { + input string + shouldErr bool + nexcept []string + pexcept []string + }{ + // positive + {"disable denial example.com example.org", false, []string{"example.com.", "example.org."}, nil}, + {"disable success example.com example.org", false, nil, []string{"example.com.", "example.org."}}, + {"disable denial", false, []string{"."}, nil}, + {"disable success", false, nil, []string{"."}}, + {"disable denial example.com example.org\ndisable success example.com example.org", false, + []string{"example.com.", "example.org."}, []string{"example.com.", "example.org."}}, + // negative + {"disable invalid example.com example.org", true, nil, nil}, + } + for i, test := range tests { + c := caddy.NewTestController("dns", fmt.Sprintf("cache {\n%s\n}", test.input)) + ca, err := cacheParse(c) + if test.shouldErr && err == nil { + t.Errorf("Test %v: Expected error but found nil", i) + continue + } else if !test.shouldErr && err != nil { + t.Errorf("Test %v: Expected no error but found error: %v", i, err) + continue + } + if test.shouldErr { + continue + } + if fmt.Sprintf("%v", test.nexcept) != fmt.Sprintf("%v", ca.nexcept) { + t.Errorf("Test %v: Expected %v but got: %v", i, test.nexcept, ca.nexcept) + } + if fmt.Sprintf("%v", test.pexcept) != fmt.Sprintf("%v", ca.pexcept) { + t.Errorf("Test %v: Expected %v but got: %v", i, test.pexcept, ca.pexcept) + } + } +} + +func TestKeepttl(t *testing.T) { + tests := []struct { + input string + shouldErr bool + }{ + // positive + {"keepttl", false}, + // negative + {"keepttl arg1", true}, + } + for i, test := range tests { + c := caddy.NewTestController("dns", fmt.Sprintf("cache {\n%s\n}", test.input)) + ca, err := cacheParse(c) + if test.shouldErr && err == nil { + t.Errorf("Test %v: Expected error but found nil", i) + continue + } else if !test.shouldErr && err != nil { + t.Errorf("Test %v: Expected no error but found error: %v", i, err) + continue + } + if test.shouldErr { + continue + } + if !ca.keepttl { + t.Errorf("Test %v: Expected keepttl enabled but disabled", i) + } + } +} diff --git a/plugin/cache/spoof_test.go b/plugin/cache/spoof_test.go new file mode 100644 index 0000000..20d7e8d --- /dev/null +++ b/plugin/cache/spoof_test.go @@ -0,0 +1,82 @@ +package cache + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestSpoof(t *testing.T) { + // Send query for example.org, get reply for example.net; should not be cached. 
+ c := New() + c.Next = spoofHandler(true) + + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + + c.ServeDNS(context.TODO(), rec, req) + + qname := rec.Msg.Question[0].Name + if c.pcache.Len() != 0 { + t.Errorf("Cached %s, while reply had %s", "example.org.", qname) + } + + // qtype + c.Next = spoofHandlerType() + req.SetQuestion("example.org.", dns.TypeMX) + + c.ServeDNS(context.TODO(), rec, req) + + qtype := rec.Msg.Question[0].Qtype + if c.pcache.Len() != 0 { + t.Errorf("Cached %s type %d, while reply had %d", "example.org.", dns.TypeMX, qtype) + } +} + +func TestResponse(t *testing.T) { + // Send query for example.org, get reply for example.net; should not be cached. + c := New() + c.Next = spoofHandler(false) + + req := new(dns.Msg) + req.SetQuestion("example.net.", dns.TypeA) + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + + c.ServeDNS(context.TODO(), rec, req) + + if c.pcache.Len() != 0 { + t.Errorf("Cached %s, while reply had response set to %t", "example.net.", rec.Msg.Response) + } +} + +// spoofHandler is a fake plugin implementation which returns a single A records for example.org. The qname in the +// question section is set to example.NET (i.e. they *don't* match). +func spoofHandler(response bool) plugin.Handler { + return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + m := new(dns.Msg) + m.SetQuestion("example.net.", dns.TypeA) + m.Response = response + m.Answer = []dns.RR{test.A("example.org. IN A 127.0.0.53")} + w.WriteMsg(m) + return dns.RcodeSuccess, nil + }) +} + +// spoofHandlerType is a fake plugin implementation which returns a single MX records for example.org. The qtype in the +// question section is set to A. +func spoofHandlerType() plugin.Handler { + return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + m := new(dns.Msg) + m.SetQuestion("example.org.", dns.TypeA) + m.Response = true + m.Answer = []dns.RR{test.MX("example.org. IN MX 10 mail.example.org.")} + w.WriteMsg(m) + return dns.RcodeSuccess, nil + }) +} diff --git a/plugin/cancel/README.md b/plugin/cancel/README.md new file mode 100644 index 0000000..64f585a --- /dev/null +++ b/plugin/cancel/README.md @@ -0,0 +1,47 @@ +# cancel + +## Name + +*cancel* - cancels a request's context after 5001 milliseconds. + +## Description + +The *cancel* plugin creates a canceling context for each request. It adds a timeout that gets +triggered after 5001 milliseconds. + +The 5001 number was chosen because the default timeout for DNS clients is 5 seconds, after that they +give up. + +A plugin interested in the cancellation status should call `plugin.Done()` on the context. If the +context was canceled due to a timeout the plugin should not write anything back to the client and +return a value indicating CoreDNS should not either; a zero return value should suffice for that. + +## Syntax + +~~~ txt +cancel [TIMEOUT] +~~~ + +* **TIMEOUT** allows setting a custom timeout. The default timeout is 5001 milliseconds (`5001 ms`) + +## Examples + +~~~ corefile +example.org { + cancel + whoami +} +~~~ + +Or with a custom timeout: + +~~~ corefile +example.org { + cancel 1s + whoami +} +~~~ + +## See Also + +The Go documentation for the context package. 
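+
+For illustration only, the following is a minimal sketch (not part of this plugin) of how a
+downstream plugin might honor the cancellation. The `slowPlugin` type and its `Next` field are
+hypothetical; `plugin.Done` and `plugin.NextOrFailure` are the CoreDNS plugin helpers used
+elsewhere in this repository:
+
+~~~ go
+// ServeDNS checks the deadline set up by *cancel* before doing any work.
+func (p slowPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
+	if plugin.Done(ctx) {
+		// The context was canceled: write nothing back and return zero so CoreDNS does not answer either.
+		return 0, nil
+	}
+	return plugin.NextOrFailure("slow", p.Next, ctx, w, r)
+}
+~~~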
diff --git a/plugin/cancel/cancel.go b/plugin/cancel/cancel.go new file mode 100644 index 0000000..23f5de4 --- /dev/null +++ b/plugin/cancel/cancel.go @@ -0,0 +1,66 @@ +// Package cancel implements a plugin adds a canceling context to each request. +package cancel + +import ( + "context" + "fmt" + "time" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + + "github.com/miekg/dns" +) + +func init() { plugin.Register("cancel", setup) } + +func setup(c *caddy.Controller) error { + ca := Cancel{} + + for c.Next() { + args := c.RemainingArgs() + switch len(args) { + case 0: + ca.timeout = 5001 * time.Millisecond + case 1: + dur, err := time.ParseDuration(args[0]) + if err != nil { + return plugin.Error("cancel", fmt.Errorf("invalid duration: %q", args[0])) + } + if dur <= 0 { + return plugin.Error("cancel", fmt.Errorf("invalid negative duration: %q", args[0])) + } + ca.timeout = dur + default: + return plugin.Error("cancel", c.ArgErr()) + } + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + ca.Next = next + return ca + }) + + return nil +} + +// Cancel is a plugin that adds a canceling context to each request's context. +type Cancel struct { + timeout time.Duration + Next plugin.Handler +} + +// ServeDNS implements the plugin.Handler interface. +func (c Cancel) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + ctx, cancel := context.WithTimeout(ctx, c.timeout) + + code, err := plugin.NextOrFailure(c.Name(), c.Next, ctx, w, r) + + cancel() + + return code, err +} + +// Name implements the Handler interface. +func (c Cancel) Name() string { return "cancel" } diff --git a/plugin/cancel/cancel_test.go b/plugin/cancel/cancel_test.go new file mode 100644 index 0000000..f775518 --- /dev/null +++ b/plugin/cancel/cancel_test.go @@ -0,0 +1,51 @@ +package cancel + +import ( + "context" + "testing" + "time" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +type sleepPlugin struct{} + +func (s sleepPlugin) Name() string { return "sleep" } + +func (s sleepPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + i := 0 + m := new(dns.Msg) + m.SetReply(r) + for { + if plugin.Done(ctx) { + m.Rcode = dns.RcodeBadTime // use BadTime to return something time related + w.WriteMsg(m) + return 0, nil + } + time.Sleep(20 * time.Millisecond) + i++ + if i > 2 { + m.Rcode = dns.RcodeServerFailure + w.WriteMsg(m) + return 0, nil + } + } +} + +func TestCancel(t *testing.T) { + ca := Cancel{Next: sleepPlugin{}, timeout: 20 * time.Millisecond} + ctx := context.Background() + + w := dnstest.NewRecorder(&test.ResponseWriter{}) + m := new(dns.Msg) + m.SetQuestion("aaa.example.com.", dns.TypeTXT) + + ca.ServeDNS(ctx, w, m) + if w.Rcode != dns.RcodeBadTime { + t.Error("Expected ServeDNS to be canceled by context") + } +} diff --git a/plugin/cancel/setup_test.go b/plugin/cancel/setup_test.go new file mode 100644 index 0000000..6079ff5 --- /dev/null +++ b/plugin/cancel/setup_test.go @@ -0,0 +1,29 @@ +package cancel + +import ( + "testing" + + "github.com/coredns/caddy" +) + +func TestSetup(t *testing.T) { + c := caddy.NewTestController("dns", `cancel`) + if err := setup(c); err != nil { + t.Errorf("Test 1, expected no errors, but got: %q", err) + } + + c = caddy.NewTestController("dns", `cancel 5s`) + if err := setup(c); err != nil { + t.Errorf("Test 2, 
expected no errors, but got: %q", err) + } + + c = caddy.NewTestController("dns", `cancel 5`) + if err := setup(c); err == nil { + t.Errorf("Test 3, expected errors, but got none") + } + + c = caddy.NewTestController("dns", `cancel -1s`) + if err := setup(c); err == nil { + t.Errorf("Test 4, expected errors, but got none") + } +} diff --git a/plugin/chaos/README.md b/plugin/chaos/README.md new file mode 100644 index 0000000..9ce5216 --- /dev/null +++ b/plugin/chaos/README.md @@ -0,0 +1,51 @@ +# chaos + +## Name + +*chaos* - allows for responding to TXT queries in the CH class. + +## Description + +This is useful for retrieving version or author information from the server by querying a TXT record +for a special domain name in the CH class. + +## Syntax + +~~~ +chaos [VERSION] [AUTHORS...] +~~~ + +* **VERSION** is the version to return. Defaults to `CoreDNS-<version>`, if not set. +* **AUTHORS** is what authors to return. This defaults to all GitHub handles in the OWNERS files. + +Note that you have to make sure that this plugin will get actual queries for the +following zones: `version.bind`, `version.server`, `authors.bind`, `hostname.bind` and +`id.server`. + +## Examples + +Specify all the zones in full. + +~~~ corefile +version.bind version.server authors.bind hostname.bind id.server { + chaos CoreDNS-001 [email protected] +} +~~~ + +Or just default to `.`: + +~~~ corefile +. { + chaos CoreDNS-001 [email protected] +} +~~~ + +And test with `dig`: + +~~~ txt +% dig @localhost CH TXT version.bind +... +;; ANSWER SECTION: +version.bind. 0 CH TXT "CoreDNS-001" +... +~~~ diff --git a/plugin/chaos/chaos.go b/plugin/chaos/chaos.go new file mode 100644 index 0000000..f4d758a --- /dev/null +++ b/plugin/chaos/chaos.go @@ -0,0 +1,58 @@ +// Package chaos implements a plugin that answer to 'CH version.bind TXT' type queries. +package chaos + +import ( + "context" + "math/rand" + "os" + "time" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// Chaos allows CoreDNS to reply to CH TXT queries and return author or +// version information. +type Chaos struct { + Next plugin.Handler + Version string + Authors []string +} + +// ServeDNS implements the plugin.Handler interface. +func (c Chaos) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + state := request.Request{W: w, Req: r} + if state.QClass() != dns.ClassCHAOS || state.QType() != dns.TypeTXT { + return plugin.NextOrFailure(c.Name(), c.Next, ctx, w, r) + } + + m := new(dns.Msg) + m.SetReply(r) + + hdr := dns.RR_Header{Name: state.QName(), Rrtype: dns.TypeTXT, Class: dns.ClassCHAOS, Ttl: 0} + switch state.Name() { + default: + return plugin.NextOrFailure(c.Name(), c.Next, ctx, w, r) + case "authors.bind.": + rnd := rand.New(rand.NewSource(time.Now().Unix())) + + for _, i := range rnd.Perm(len(c.Authors)) { + m.Answer = append(m.Answer, &dns.TXT{Hdr: hdr, Txt: []string{c.Authors[i]}}) + } + case "version.bind.", "version.server.": + m.Answer = []dns.RR{&dns.TXT{Hdr: hdr, Txt: []string{c.Version}}} + case "hostname.bind.", "id.server.": + hostname, err := os.Hostname() + if err != nil { + hostname = "localhost" + } + m.Answer = []dns.RR{&dns.TXT{Hdr: hdr, Txt: []string{trim(hostname)}}} + } + w.WriteMsg(m) + return 0, nil +} + +// Name implements the Handler interface. 
+func (c Chaos) Name() string { return "chaos" } diff --git a/plugin/chaos/chaos_test.go b/plugin/chaos/chaos_test.go new file mode 100644 index 0000000..e5d4a55 --- /dev/null +++ b/plugin/chaos/chaos_test.go @@ -0,0 +1,80 @@ +package chaos + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestChaos(t *testing.T) { + em := Chaos{ + Version: version, + Authors: []string{"Miek Gieben"}, + } + + tests := []struct { + next plugin.Handler + qname string + qtype uint16 + expectedCode int + expectedReply string + expectedErr error + }{ + { + next: test.NextHandler(dns.RcodeSuccess, nil), + qname: "version.bind", + expectedCode: dns.RcodeSuccess, + expectedReply: version, + expectedErr: nil, + }, + { + next: test.NextHandler(dns.RcodeSuccess, nil), + qname: "authors.bind", + expectedCode: dns.RcodeSuccess, + expectedReply: "Miek Gieben", + expectedErr: nil, + }, + { + next: test.NextHandler(dns.RcodeSuccess, nil), + qname: "authors.bind", + qtype: dns.TypeSRV, + expectedCode: dns.RcodeSuccess, + expectedErr: nil, + }, + } + + ctx := context.TODO() + + for i, tc := range tests { + req := new(dns.Msg) + if tc.qtype == 0 { + tc.qtype = dns.TypeTXT + } + req.SetQuestion(dns.Fqdn(tc.qname), tc.qtype) + req.Question[0].Qclass = dns.ClassCHAOS + em.Next = tc.next + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + code, err := em.ServeDNS(ctx, rec, req) + + if err != tc.expectedErr { + t.Errorf("Test %d: Expected error %v, but got %v", i, tc.expectedErr, err) + } + if code != tc.expectedCode { + t.Errorf("Test %d: Expected status code %d, but got %d", i, tc.expectedCode, code) + } + if tc.expectedReply != "" { + answer := rec.Msg.Answer[0].(*dns.TXT).Txt[0] + if answer != tc.expectedReply { + t.Errorf("Test %d: Expected answer %s, but got %s", i, tc.expectedReply, answer) + } + } + } +} + +const version = "CoreDNS-001" diff --git a/plugin/chaos/fuzz.go b/plugin/chaos/fuzz.go new file mode 100644 index 0000000..001cf1d --- /dev/null +++ b/plugin/chaos/fuzz.go @@ -0,0 +1,13 @@ +//go:build gofuzz + +package chaos + +import ( + "github.com/coredns/coredns/plugin/pkg/fuzz" +) + +// Fuzz fuzzes cache. +func Fuzz(data []byte) int { + c := Chaos{} + return fuzz.Do(c, data) +} diff --git a/plugin/chaos/log_test.go b/plugin/chaos/log_test.go new file mode 100644 index 0000000..92c98af --- /dev/null +++ b/plugin/chaos/log_test.go @@ -0,0 +1,5 @@ +package chaos + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/chaos/setup.go b/plugin/chaos/setup.go new file mode 100644 index 0000000..ce0eb7a --- /dev/null +++ b/plugin/chaos/setup.go @@ -0,0 +1,66 @@ +//go:generate go run owners_generate.go + +package chaos + +import ( + "sort" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" +) + +func init() { plugin.Register("chaos", setup) } + +func setup(c *caddy.Controller) error { + version, authors, err := parse(c) + if err != nil { + return plugin.Error("chaos", err) + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + return Chaos{Next: next, Version: version, Authors: authors} + }) + + return nil +} + +func parse(c *caddy.Controller) (string, []string, error) { + // Set here so we pick up AppName and AppVersion that get set in coremain's init(). 
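+ // chaosVersion is only returned when the Corefile provides no explicit VERSION argument.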
+ chaosVersion = caddy.AppName + "-" + caddy.AppVersion + version := "" + + if c.Next() { + args := c.RemainingArgs() + if len(args) == 0 { + return trim(chaosVersion), Owners, nil + } + if len(args) == 1 { + return trim(args[0]), Owners, nil + } + + version = args[0] + authors := make(map[string]struct{}) + for _, a := range args[1:] { + authors[a] = struct{}{} + } + list := []string{} + for k := range authors { + k = trim(k) // limit size to 255 chars + list = append(list, k) + } + sort.Strings(list) + return version, list, nil + } + + return version, Owners, nil +} + +func trim(s string) string { + if len(s) < 256 { + return s + } + return s[:255] +} + +var chaosVersion string diff --git a/plugin/chaos/setup_test.go b/plugin/chaos/setup_test.go new file mode 100644 index 0000000..2c45d86 --- /dev/null +++ b/plugin/chaos/setup_test.go @@ -0,0 +1,54 @@ +package chaos + +import ( + "strings" + "testing" + + "github.com/coredns/caddy" +) + +func TestSetupChaos(t *testing.T) { + tests := []struct { + input string + shouldErr bool + expectedVersion string // expected version. + expectedAuthor string // expected author (string, although we get a slice). + expectedErrContent string // substring from the expected error. Empty for positive cases. + }{ + // positive + { + `chaos v2`, false, "v2", "", "", + }, + { + `chaos v3 "Miek Gieben"`, false, "v3", "Miek Gieben", "", + }, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + version, authors, err := parse(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected error but found %s for input %s", i, err, test.input) + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: Expected no error but found one for input %s. Error was: %v", i, test.input, err) + } + + if !strings.Contains(err.Error(), test.expectedErrContent) { + t.Errorf("Test %d: Expected error to contain: %v, found error: %v, input: %s", i, test.expectedErrContent, err, test.input) + } + } + + if !test.shouldErr && version != test.expectedVersion { + t.Errorf("Test %d: Chaos not correctly set for input %s. Expected: %s, actual: %s", i, test.input, test.expectedVersion, version) + } + if !test.shouldErr && authors != nil && test.expectedAuthor != "" { + if authors[0] != test.expectedAuthor { + t.Errorf("Test %d: Chaos not correctly set for input %s. Expected: '%s', actual: '%s'", i, test.input, test.expectedAuthor, authors[0]) + } + } + } +} diff --git a/plugin/chaos/zowners.go b/plugin/chaos/zowners.go new file mode 100644 index 0000000..419ca3c --- /dev/null +++ b/plugin/chaos/zowners.go @@ -0,0 +1,4 @@ +package chaos + +// Owners are all GitHub handlers of all maintainers. +var Owners = []string{"Tantalor93", "bradbeam", "chrisohaver", "darshanime", "dilyevsky", "ekleiner", "greenpau", "ihac", "inigohu", "isolus", "jameshartig", "johnbelamaric", "miekg", "mqasimsarfraz", "nchrisdk", "nitisht", "pmoroney", "rajansandeep", "rdrozhdzh", "rtreffer", "snebel29", "stp-ip", "superq", "varyoo", "ykhr53", "yongtang", "zouyee"} diff --git a/plugin/clouddns/README.md b/plugin/clouddns/README.md new file mode 100644 index 0000000..1e12281 --- /dev/null +++ b/plugin/clouddns/README.md @@ -0,0 +1,73 @@ +# clouddns + +## Name + +*clouddns* - enables serving zone data from GCP Cloud DNS. + +## Description + +The *clouddns* plugin is useful for serving zones from resource record +sets in GCP Cloud DNS. This plugin supports all [Google Cloud DNS +records](https://cloud.google.com/dns/docs/overview#supported_dns_record_types). 
This plugin can +be used when CoreDNS is deployed on GCP or elsewhere. Note that this plugin accesses the resource +records through the Google Cloud API. For records in a privately hosted zone, it is not necessary to +place CoreDNS and this plugin in the associated VPC network. In fact the private hosted zone could +be created without any associated VPC and this plugin could still access the resource records under +the hosted zone. + +## Syntax + +~~~ txt +clouddns [ZONE:PROJECT_ID:HOSTED_ZONE_NAME...] { + credentials [FILENAME] + fallthrough [ZONES...] +} +~~~ + +* **ZONE** the name of the domain to be accessed. When there are multiple zones with overlapping + domains (private vs. public hosted zone), CoreDNS does the lookup in the given order here. + Therefore, for a non-existing resource record, SOA response will be from the rightmost zone. + +* **PROJECT\_ID** the project ID of the Google Cloud project. + +* **HOSTED\_ZONE\_NAME** the name of the hosted zone that contains the resource record sets to be + accessed. + +* `credentials` is used for reading the credential file from **FILENAME** (normally a .json file). + This field is optional. If this field is not provided then authentication will be done automatically, + e.g., through environmental variable `GOOGLE_APPLICATION_CREDENTIALS`. Please see + Google Cloud's [authentication method](https://cloud.google.com/docs/authentication) for more details. + +* `fallthrough` If zone matches and no record can be generated, pass request to the next plugin. + If **[ZONES...]** is omitted, then fallthrough happens for all zones for which the plugin is + authoritative. If specific zones are listed (for example `in-addr.arpa` and `ip6.arpa`), then + only queries for those zones will be subject to fallthrough. + +## Examples + +Enable clouddns with implicit GCP credentials and resolve CNAMEs via 10.0.0.1: + +~~~ txt +example.org { + clouddns example.org.:gcp-example-project:example-zone + forward . 10.0.0.1 +} +~~~ + +Enable clouddns with fallthrough: + +~~~ txt +example.org { + clouddns example.org.:gcp-example-project:example-zone example.com.:gcp-example-project:example-zone-2 { + fallthrough example.gov. + } +} +~~~ + +Enable clouddns with multiple hosted zones with the same domain: + +~~~ txt +. { + clouddns example.org.:gcp-example-project:example-zone example.com.:gcp-example-project:other-example-zone +} +~~~ diff --git a/plugin/clouddns/clouddns.go b/plugin/clouddns/clouddns.go new file mode 100644 index 0000000..0e31a40 --- /dev/null +++ b/plugin/clouddns/clouddns.go @@ -0,0 +1,227 @@ +// Package clouddns implements a plugin that returns resource records +// from GCP Cloud DNS. +package clouddns + +import ( + "context" + "errors" + "fmt" + "strings" + "sync" + "time" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/file" + "github.com/coredns/coredns/plugin/pkg/fall" + "github.com/coredns/coredns/plugin/pkg/upstream" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" + gcp "google.golang.org/api/dns/v1" +) + +// CloudDNS is a plugin that returns RR from GCP Cloud DNS. 
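+// Zone data is fetched through the Cloud DNS API and kept in in-memory file.Zone objects;
+// Run (below) refreshes every configured zone roughly once a minute until the context is
+// cancelled.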
+type CloudDNS struct { + Next plugin.Handler + Fall fall.F + + zoneNames []string + client gcpDNS + upstream *upstream.Upstream + + zMu sync.RWMutex + zones zones +} + +type zone struct { + projectName string + zoneName string + z *file.Zone + dns string +} + +type zones map[string][]*zone + +// New reads from the keys map which uses domain names as its key and a colon separated +// string of project name and hosted zone name lists as its values, validates +// that each domain name/zone id pair does exist, and returns a new *CloudDNS. +// In addition to this, upstream is passed for doing recursive queries against CNAMEs. +// Returns error if it cannot verify any given domain name/zone id pair. +func New(ctx context.Context, c gcpDNS, keys map[string][]string, up *upstream.Upstream) (*CloudDNS, error) { + zones := make(map[string][]*zone, len(keys)) + zoneNames := make([]string, 0, len(keys)) + for dnsName, hostedZoneDetails := range keys { + for _, hostedZone := range hostedZoneDetails { + ss := strings.SplitN(hostedZone, ":", 2) + if len(ss) != 2 { + return nil, errors.New("either project or zone name missing") + } + err := c.zoneExists(ss[0], ss[1]) + if err != nil { + return nil, err + } + fqdnDNSName := dns.Fqdn(dnsName) + if _, ok := zones[fqdnDNSName]; !ok { + zoneNames = append(zoneNames, fqdnDNSName) + } + zones[fqdnDNSName] = append(zones[fqdnDNSName], &zone{projectName: ss[0], zoneName: ss[1], dns: fqdnDNSName, z: file.NewZone(fqdnDNSName, "")}) + } + } + return &CloudDNS{ + client: c, + zoneNames: zoneNames, + zones: zones, + upstream: up, + }, nil +} + +// Run executes first update, spins up an update forever-loop. +// Returns error if first update fails. +func (h *CloudDNS) Run(ctx context.Context) error { + if err := h.updateZones(ctx); err != nil { + return err + } + go func() { + delay := 1 * time.Minute + timer := time.NewTimer(delay) + defer timer.Stop() + for { + timer.Reset(delay) + select { + case <-ctx.Done(): + log.Debugf("Breaking out of CloudDNS update loop for %v: %v", h.zoneNames, ctx.Err()) + return + case <-timer.C: + if err := h.updateZones(ctx); err != nil && ctx.Err() == nil /* Don't log error if ctx expired. */ { + log.Errorf("Failed to update zones %v: %v", h.zoneNames, err) + } + } + } + }() + return nil +} + +// ServeDNS implements the plugin.Handler interface. +func (h *CloudDNS) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + state := request.Request{W: w, Req: r} + qname := state.Name() + + zName := plugin.Zones(h.zoneNames).Matches(qname) + if zName == "" { + return plugin.NextOrFailure(h.Name(), h.Next, ctx, w, r) + } + + z, ok := h.zones[zName] // ok true if we are authoritative for the zone + if !ok || z == nil { + return dns.RcodeServerFailure, nil + } + + m := new(dns.Msg) + m.SetReply(r) + m.Authoritative = true + var result file.Result + + for _, hostedZone := range z { + h.zMu.RLock() + m.Answer, m.Ns, m.Extra, result = hostedZone.z.Lookup(ctx, state, qname) + h.zMu.RUnlock() + + // Take the answer if it's non-empty OR if there is another + // record type exists for this name (NODATA). 
+ if len(m.Answer) != 0 || result == file.NoData { + break + } + } + + if len(m.Answer) == 0 && result != file.NoData && h.Fall.Through(qname) { + return plugin.NextOrFailure(h.Name(), h.Next, ctx, w, r) + } + + switch result { + case file.Success: + case file.NoData: + case file.NameError: + m.Rcode = dns.RcodeNameError + case file.Delegation: + m.Authoritative = false + case file.ServerFailure: + return dns.RcodeServerFailure, nil + } + + w.WriteMsg(m) + return dns.RcodeSuccess, nil +} + +func updateZoneFromRRS(rrs *gcp.ResourceRecordSetsListResponse, z *file.Zone) error { + for _, rr := range rrs.Rrsets { + var rfc1035 string + var r dns.RR + var err error + for _, value := range rr.Rrdatas { + if rr.Type == "CNAME" || rr.Type == "PTR" { + value = dns.Fqdn(value) + } + // Assemble RFC 1035 conforming record to pass into dns scanner. + rfc1035 = fmt.Sprintf("%s %d IN %s %s", dns.Fqdn(rr.Name), rr.Ttl, rr.Type, value) + r, err = dns.NewRR(rfc1035) + if err != nil { + return fmt.Errorf("failed to parse resource record: %v", err) + } + + err = z.Insert(r) + if err != nil { + return fmt.Errorf("failed to insert record: %v", err) + } + } + } + return nil +} + +// updateZones re-queries resource record sets for each zone and updates the +// zone object. +// Returns error if any zones error'ed out, but waits for other zones to +// complete first. +func (h *CloudDNS) updateZones(ctx context.Context) error { + errc := make(chan error) + defer close(errc) + for zName, z := range h.zones { + go func(zName string, z []*zone) { + var err error + var rrListResponse *gcp.ResourceRecordSetsListResponse + defer func() { + errc <- err + }() + + for i, hostedZone := range z { + newZ := file.NewZone(zName, "") + newZ.Upstream = h.upstream + rrListResponse, err = h.client.listRRSets(ctx, hostedZone.projectName, hostedZone.zoneName) + if err != nil { + err = fmt.Errorf("failed to list resource records for %v:%v:%v from gcp: %v", zName, hostedZone.projectName, hostedZone.zoneName, err) + return + } + updateZoneFromRRS(rrListResponse, newZ) + + h.zMu.Lock() + (*z[i]).z = newZ + h.zMu.Unlock() + } + }(zName, z) + } + // Collect errors (if any). This will also sync on all zones updates + // completion. + var errs []string + for i := 0; i < len(h.zones); i++ { + err := <-errc + if err != nil { + errs = append(errs, err.Error()) + } + } + if len(errs) != 0 { + return fmt.Errorf("errors updating zones: %v", errs) + } + return nil +} + +// Name implements the Handler interface. 
+func (h *CloudDNS) Name() string { return "clouddns" } diff --git a/plugin/clouddns/clouddns_test.go b/plugin/clouddns/clouddns_test.go new file mode 100644 index 0000000..829aa71 --- /dev/null +++ b/plugin/clouddns/clouddns_test.go @@ -0,0 +1,327 @@ +package clouddns + +import ( + "context" + "errors" + "reflect" + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/pkg/fall" + "github.com/coredns/coredns/plugin/pkg/upstream" + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" + gcp "google.golang.org/api/dns/v1" +) + +type fakeGCPClient struct { + *gcp.Service +} + +func (c fakeGCPClient) zoneExists(projectName, hostedZoneName string) error { + return nil +} + +func (c fakeGCPClient) listRRSets(ctx context.Context, projectName, hostedZoneName string) (*gcp.ResourceRecordSetsListResponse, error) { + if projectName == "bad-project" || hostedZoneName == "bad-zone" { + return nil, errors.New("the 'parameters.managedZone' resource named 'bad-zone' does not exist") + } + + var rr []*gcp.ResourceRecordSet + + if hostedZoneName == "sample-zone-1" { + rr = []*gcp.ResourceRecordSet{ + { + Name: "example.org.", + Ttl: 300, + Type: "A", + Rrdatas: []string{"1.2.3.4"}, + }, + { + Name: "www.example.org", + Ttl: 300, + Type: "A", + Rrdatas: []string{"1.2.3.4"}, + }, + { + Name: "*.www.example.org", + Ttl: 300, + Type: "CNAME", + Rrdatas: []string{"www.example.org"}, + }, + { + Name: "example.org.", + Ttl: 300, + Type: "AAAA", + Rrdatas: []string{"2001:db8:85a3::8a2e:370:7334"}, + }, + { + Name: "sample.example.org", + Ttl: 300, + Type: "CNAME", + Rrdatas: []string{"example.org"}, + }, + { + Name: "example.org.", + Ttl: 300, + Type: "PTR", + Rrdatas: []string{"ptr.example.org."}, + }, + { + Name: "org.", + Ttl: 300, + Type: "SOA", + Rrdatas: []string{"ns-cloud-c1.googledomains.com. cloud-dns-hostmaster.google.com. 1 21600 300 259200 300"}, + }, + { + Name: "com.", + Ttl: 300, + Type: "NS", + Rrdatas: []string{"ns-cloud-c4.googledomains.com."}, + }, + { + Name: "split-example.gov.", + Ttl: 300, + Type: "A", + Rrdatas: []string{"1.2.3.4"}, + }, + { + Name: "swag.", + Ttl: 300, + Type: "YOLO", + Rrdatas: []string{"foobar"}, + }, + } + } else { + rr = []*gcp.ResourceRecordSet{ + { + Name: "split-example.org.", + Ttl: 300, + Type: "A", + Rrdatas: []string{"1.2.3.4"}, + }, + { + Name: "other-example.org.", + Ttl: 300, + Type: "A", + Rrdatas: []string{"3.5.7.9"}, + }, + { + Name: "org.", + Ttl: 300, + Type: "SOA", + Rrdatas: []string{"ns-cloud-e1.googledomains.com. cloud-dns-hostmaster.google.com. 
1 21600 300 259200 300"}, + }, + { + Name: "_dummy._tcp.example.org.", + Ttl: 300, + Type: "SRV", + Rrdatas: []string{ + "0 0 5269 split-example.org", + "0 0 5269 other-example.org", + }, + }, + } + } + + return &gcp.ResourceRecordSetsListResponse{Rrsets: rr}, nil +} + +func TestCloudDNS(t *testing.T) { + ctx := context.Background() + + r, err := New(ctx, fakeGCPClient{}, map[string][]string{"bad.": {"bad-project:bad-zone"}}, &upstream.Upstream{}) + if err != nil { + t.Fatalf("Failed to create Cloud DNS: %v", err) + } + if err = r.Run(ctx); err == nil { + t.Fatalf("Expected errors for zone bad.") + } + + r, err = New(ctx, fakeGCPClient{}, map[string][]string{"org.": {"sample-project-1:sample-zone-2", "sample-project-1:sample-zone-1"}, "gov.": {"sample-project-1:sample-zone-2", "sample-project-1:sample-zone-1"}}, &upstream.Upstream{}) + if err != nil { + t.Fatalf("Failed to create Cloud DNS: %v", err) + } + r.Fall = fall.Zero + r.Fall.SetZonesFromArgs([]string{"gov."}) + r.Next = test.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + state := request.Request{W: w, Req: r} + qname := state.Name() + m := new(dns.Msg) + rcode := dns.RcodeServerFailure + if qname == "example.gov." { + m.SetReply(r) + rr, err := dns.NewRR("example.gov. 300 IN A 2.4.6.8") + if err != nil { + t.Fatalf("Failed to create Resource Record: %v", err) + } + m.Answer = []dns.RR{rr} + + m.Authoritative = true + rcode = dns.RcodeSuccess + } + + m.SetRcode(r, rcode) + w.WriteMsg(m) + return rcode, nil + }) + err = r.Run(ctx) + if err != nil { + t.Fatalf("Failed to initialize Cloud DNS: %v", err) + } + + tests := []struct { + qname string + qtype uint16 + wantRetCode int + wantAnswer []string // ownernames for the records in the additional section. + wantMsgRCode int + wantNS []string + expectedErr error + }{ + // 0. example.org A found - success. + { + qname: "example.org", + qtype: dns.TypeA, + wantAnswer: []string{"example.org. 300 IN A 1.2.3.4"}, + }, + // 1. example.org AAAA found - success. + { + qname: "example.org", + qtype: dns.TypeAAAA, + wantAnswer: []string{"example.org. 300 IN AAAA 2001:db8:85a3::8a2e:370:7334"}, + }, + // 2. exampled.org PTR found - success. + { + qname: "example.org", + qtype: dns.TypePTR, + wantAnswer: []string{"example.org. 300 IN PTR ptr.example.org."}, + }, + // 3. sample.example.org points to example.org CNAME. + // Query must return both CNAME and A recs. + { + qname: "sample.example.org", + qtype: dns.TypeA, + wantAnswer: []string{ + "sample.example.org. 300 IN CNAME example.org.", + "example.org. 300 IN A 1.2.3.4", + }, + }, + // 4. Explicit CNAME query for sample.example.org. + // Query must return just CNAME. + { + qname: "sample.example.org", + qtype: dns.TypeCNAME, + wantAnswer: []string{"sample.example.org. 300 IN CNAME example.org."}, + }, + // 5. Explicit SOA query for example.org. + { + qname: "example.org", + qtype: dns.TypeNS, + wantNS: []string{"org. 300 IN SOA ns-cloud-c1.googledomains.com. cloud-dns-hostmaster.google.com. 1 21600 300 259200 300"}, + }, + // 6. AAAA query for split-example.org must return NODATA. + { + qname: "split-example.gov", + qtype: dns.TypeAAAA, + wantRetCode: dns.RcodeSuccess, + wantNS: []string{"org. 300 IN SOA ns-cloud-c1.googledomains.com. cloud-dns-hostmaster.google.com. 1 21600 300 259200 300"}, + }, + // 7. Zone not configured. + { + qname: "badexample.com", + qtype: dns.TypeA, + wantRetCode: dns.RcodeServerFailure, + wantMsgRCode: dns.RcodeServerFailure, + }, + // 8. No record found. Return SOA record. 
+ { + qname: "bad.org", + qtype: dns.TypeA, + wantRetCode: dns.RcodeSuccess, + wantMsgRCode: dns.RcodeNameError, + wantNS: []string{"org. 300 IN SOA ns-cloud-c1.googledomains.com. cloud-dns-hostmaster.google.com. 1 21600 300 259200 300"}, + }, + // 9. No record found. Fallthrough. + { + qname: "example.gov", + qtype: dns.TypeA, + wantAnswer: []string{"example.gov. 300 IN A 2.4.6.8"}, + }, + // 10. other-zone.example.org is stored in a different hosted zone. success + { + qname: "other-example.org", + qtype: dns.TypeA, + wantAnswer: []string{"other-example.org. 300 IN A 3.5.7.9"}, + }, + // 11. split-example.org only has A record. Expect NODATA. + { + qname: "split-example.org", + qtype: dns.TypeAAAA, + wantNS: []string{"org. 300 IN SOA ns-cloud-e1.googledomains.com. cloud-dns-hostmaster.google.com. 1 21600 300 259200 300"}, + }, + // 12. *.www.example.org is a wildcard CNAME to www.example.org. + { + qname: "a.www.example.org", + qtype: dns.TypeA, + wantAnswer: []string{ + "a.www.example.org. 300 IN CNAME www.example.org.", + "www.example.org. 300 IN A 1.2.3.4", + }, + }, + // 13. example.org SRV found with 2 answers - success. + { + qname: "_dummy._tcp.example.org.", + qtype: dns.TypeSRV, + wantAnswer: []string{ + "_dummy._tcp.example.org. 300 IN SRV 0 0 5269 split-example.org.", + "_dummy._tcp.example.org. 300 IN SRV 0 0 5269 other-example.org.", + }, + }, + } + + for ti, tc := range tests { + req := new(dns.Msg) + req.SetQuestion(dns.Fqdn(tc.qname), tc.qtype) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + code, err := r.ServeDNS(ctx, rec, req) + + if err != tc.expectedErr { + t.Fatalf("Test %d: Expected error %v, but got %v", ti, tc.expectedErr, err) + } + if code != tc.wantRetCode { + t.Fatalf("Test %d: Expected returned status code %s, but got %s", ti, dns.RcodeToString[tc.wantRetCode], dns.RcodeToString[code]) + } + + if tc.wantMsgRCode != rec.Msg.Rcode { + t.Errorf("Test %d: Unexpected msg status code. Want: %s, got: %s", ti, dns.RcodeToString[tc.wantMsgRCode], dns.RcodeToString[rec.Msg.Rcode]) + } + + if len(tc.wantAnswer) != len(rec.Msg.Answer) { + t.Errorf("Test %d: Unexpected number of Answers. Want: %d, got: %d", ti, len(tc.wantAnswer), len(rec.Msg.Answer)) + } else { + for i, gotAnswer := range rec.Msg.Answer { + if gotAnswer.String() != tc.wantAnswer[i] { + t.Errorf("Test %d: Unexpected answer.\nWant:\n\t%s\nGot:\n\t%s", ti, tc.wantAnswer[i], gotAnswer) + } + } + } + + if len(tc.wantNS) != len(rec.Msg.Ns) { + t.Errorf("Test %d: Unexpected NS number. Want: %d, got: %d", ti, len(tc.wantNS), len(rec.Msg.Ns)) + } else { + for i, ns := range rec.Msg.Ns { + got, ok := ns.(*dns.SOA) + if !ok { + t.Errorf("Test %d: Unexpected NS type. Want: SOA, got: %v", ti, reflect.TypeOf(got)) + } + if got.String() != tc.wantNS[i] { + t.Errorf("Test %d: Unexpected NS.\nWant: %v\nGot: %v", ti, tc.wantNS[i], got) + } + } + } + } +} diff --git a/plugin/clouddns/gcp.go b/plugin/clouddns/gcp.go new file mode 100644 index 0000000..b02ab2b --- /dev/null +++ b/plugin/clouddns/gcp.go @@ -0,0 +1,40 @@ +package clouddns + +import ( + "context" + + gcp "google.golang.org/api/dns/v1" +) + +type gcpDNS interface { + zoneExists(projectName, hostedZoneName string) error + listRRSets(ctx context.Context, projectName, hostedZoneName string) (*gcp.ResourceRecordSetsListResponse, error) +} + +type gcpClient struct { + *gcp.Service +} + +// zoneExists is a wrapper method around `gcp.Service.ManagedZones.Get` +// it checks if the provided zone name for a given project exists. 
+func (c gcpClient) zoneExists(projectName, hostedZoneName string) error { + _, err := c.ManagedZones.Get(projectName, hostedZoneName).Do() + if err != nil { + return err + } + return nil +} + +// listRRSets is a wrapper method around `gcp.Service.ResourceRecordSets.List` +// it fetches and returns the record sets for a hosted zone. +func (c gcpClient) listRRSets(ctx context.Context, projectName, hostedZoneName string) (*gcp.ResourceRecordSetsListResponse, error) { + req := c.ResourceRecordSets.List(projectName, hostedZoneName) + var rs []*gcp.ResourceRecordSet + if err := req.Pages(ctx, func(page *gcp.ResourceRecordSetsListResponse) error { + rs = append(rs, page.Rrsets...) + return nil + }); err != nil { + return nil, err + } + return &gcp.ResourceRecordSetsListResponse{Rrsets: rs}, nil +} diff --git a/plugin/clouddns/log_test.go b/plugin/clouddns/log_test.go new file mode 100644 index 0000000..148635b --- /dev/null +++ b/plugin/clouddns/log_test.go @@ -0,0 +1,5 @@ +package clouddns + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/clouddns/setup.go b/plugin/clouddns/setup.go new file mode 100644 index 0000000..cfd7eec --- /dev/null +++ b/plugin/clouddns/setup.go @@ -0,0 +1,108 @@ +package clouddns + +import ( + "context" + "strings" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/fall" + clog "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/plugin/pkg/upstream" + + gcp "google.golang.org/api/dns/v1" + "google.golang.org/api/option" +) + +var log = clog.NewWithPlugin("clouddns") + +func init() { plugin.Register("clouddns", setup) } + +// exposed for testing +var f = func(ctx context.Context, opt option.ClientOption) (gcpDNS, error) { + var err error + var client *gcp.Service + if opt != nil { + client, err = gcp.NewService(ctx, opt) + } else { + // if credentials file is not provided in the Corefile + // authenticate the client using env variables + client, err = gcp.NewService(ctx) + } + return gcpClient{client}, err +} + +func setup(c *caddy.Controller) error { + for c.Next() { + keyPairs := map[string]struct{}{} + keys := map[string][]string{} + + var fall fall.F + up := upstream.New() + + args := c.RemainingArgs() + + for i := 0; i < len(args); i++ { + parts := strings.SplitN(args[i], ":", 3) + if len(parts) != 3 { + return plugin.Error("clouddns", c.Errf("invalid zone %q", args[i])) + } + dnsName, projectName, hostedZone := parts[0], parts[1], parts[2] + if dnsName == "" || projectName == "" || hostedZone == "" { + return plugin.Error("clouddns", c.Errf("invalid zone %q", args[i])) + } + if _, ok := keyPairs[args[i]]; ok { + return plugin.Error("clouddns", c.Errf("conflict zone %q", args[i])) + } + + keyPairs[args[i]] = struct{}{} + keys[dnsName] = append(keys[dnsName], projectName+":"+hostedZone) + } + + var opt option.ClientOption + for c.NextBlock() { + switch c.Val() { + case "upstream": + c.RemainingArgs() + case "credentials": + if c.NextArg() { + opt = option.WithCredentialsFile(c.Val()) + } else { + return plugin.Error("clouddns", c.ArgErr()) + } + case "fallthrough": + fall.SetZonesFromArgs(c.RemainingArgs()) + default: + return plugin.Error("clouddns", c.Errf("unknown property %q", c.Val())) + } + } + + ctx, cancel := context.WithCancel(context.Background()) + client, err := f(ctx, opt) + if err != nil { + cancel() + return err + } + + h, err := New(ctx, client, keys, up) + if err 
!= nil { + cancel() + return plugin.Error("clouddns", c.Errf("failed to create plugin: %v", err)) + } + h.Fall = fall + + if err := h.Run(ctx); err != nil { + cancel() + return plugin.Error("clouddns", c.Errf("failed to initialize plugin: %v", err)) + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + h.Next = next + return h + }) + c.OnShutdown(func() error { cancel(); return nil }) + } + + return nil +} diff --git a/plugin/clouddns/setup_test.go b/plugin/clouddns/setup_test.go new file mode 100644 index 0000000..ae2262e --- /dev/null +++ b/plugin/clouddns/setup_test.go @@ -0,0 +1,49 @@ +package clouddns + +import ( + "context" + "testing" + + "github.com/coredns/caddy" + + "google.golang.org/api/option" +) + +func TestSetupCloudDNS(t *testing.T) { + f = func(ctx context.Context, opt option.ClientOption) (gcpDNS, error) { + return fakeGCPClient{}, nil + } + + tests := []struct { + body string + expectedError bool + }{ + {`clouddns`, false}, + {`clouddns :`, true}, + {`clouddns ::`, true}, + {`clouddns example.org.:example-project:zone-name`, false}, + {`clouddns example.org.:example-project:zone-name { }`, false}, + {`clouddns example.org.:example-project: { }`, true}, + {`clouddns example.org.:example-project:zone-name { }`, false}, + {`clouddns example.org.:example-project:zone-name { wat +}`, true}, + {`clouddns example.org.:example-project:zone-name { + fallthrough +}`, false}, + {`clouddns example.org.:example-project:zone-name { + credentials +}`, true}, + {`clouddns example.org.:example-project:zone-name example.org.:example-project:zone-name { + }`, true}, + + {`clouddns example.org { + }`, true}, + } + + for _, test := range tests { + c := caddy.NewTestController("dns", test.body) + if err := setup(c); (err == nil) == test.expectedError { + t.Errorf("Unexpected errors: %v", err) + } + } +} diff --git a/plugin/debug/README.md b/plugin/debug/README.md new file mode 100644 index 0000000..4376723 --- /dev/null +++ b/plugin/debug/README.md @@ -0,0 +1,51 @@ +# debug + +## Name + +*debug* - disables the automatic recovery upon a crash so that you'll get a nice stack trace. + +## Description + +Normally CoreDNS will recover from panics; using *debug* inhibits this. The main use of *debug* is +to help in testing. A side effect of using *debug* is that `log.Debug` and `log.Debugf` messages +will be printed to standard output. + +Note that the *errors* plugin (if loaded) will also set a `recover`, negating this setting. + +Enabling this plugin is process-wide: enabling *debug* in at least one server block enables +debug mode globally. + +## Syntax + +~~~ txt +debug +~~~ + +Some plugins will send debug log DNS messages. This is done in the following format: + +~~~ +debug: 000000 00 0a 01 00 00 01 00 00 00 00 00 01 07 65 78 61 +debug: 000010 6d 70 6c 65 05 6c 6f 63 61 6c 00 00 01 00 01 00 +debug: 000020 00 29 10 00 00 00 80 00 00 00 +debug: 00002a +~~~ + +Using `text2pcap` (part of Wireshark), this can be converted back to binary, with the following +command line: `text2pcap -i 17 -u 53,53`, where 17 is the protocol (UDP) and 53 are the ports. These +ports allow Wireshark to detect these packets as DNS messages. + +Each plugin can decide whether to dump messages to aid in debugging. + +## Examples + +Disable the ability to recover from crashes and show debug logging: + +~~~ corefile +. { + debug +} +~~~ + +## See Also + +<https://www.wireshark.org/docs/man-pages/text2pcap.html>. 
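+
+As a worked example of the conversion described above (a sketch only; the log file name
+`coredns.log` and the output name `dns.pcap` are placeholders, not names used by the plugin),
+the `debug:` prefix has to be stripped before the dump is fed to `text2pcap`:
+
+~~~ sh
+sed -n 's/^debug: //p' coredns.log | text2pcap -i 17 -u 53,53 - dns.pcap
+~~~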
diff --git a/plugin/debug/debug.go b/plugin/debug/debug.go new file mode 100644 index 0000000..7fb6861 --- /dev/null +++ b/plugin/debug/debug.go @@ -0,0 +1,22 @@ +package debug + +import ( + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" +) + +func init() { plugin.Register("debug", setup) } + +func setup(c *caddy.Controller) error { + config := dnsserver.GetConfig(c) + + for c.Next() { + if c.NextArg() { + return plugin.Error("debug", c.ArgErr()) + } + config.Debug = true + } + + return nil +} diff --git a/plugin/debug/debug_test.go b/plugin/debug/debug_test.go new file mode 100644 index 0000000..71ebf37 --- /dev/null +++ b/plugin/debug/debug_test.go @@ -0,0 +1,44 @@ +package debug + +import ( + "testing" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" +) + +func TestDebug(t *testing.T) { + tests := []struct { + input string + shouldErr bool + expectedDebug bool + }{ + // positive + { + `debug`, false, true, + }, + // negative + { + `debug off`, true, false, + }, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + err := setup(c) + cfg := dnsserver.GetConfig(c) + + if test.shouldErr && err == nil { + t.Fatalf("Test %d: Expected error but found %s for input %s", i, err, test.input) + } + + if err != nil { + if !test.shouldErr { + t.Fatalf("Test %d: Expected no error but found one for input %s. Error was: %v", i, test.input, err) + } + } + if cfg.Debug != test.expectedDebug { + t.Fatalf("Test %d: Expected debug to be: %t, but got: %t, input: %s", i, test.expectedDebug, cfg.Debug, test.input) + } + } +} diff --git a/plugin/debug/log_test.go b/plugin/debug/log_test.go new file mode 100644 index 0000000..6e256db --- /dev/null +++ b/plugin/debug/log_test.go @@ -0,0 +1,5 @@ +package debug + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/debug/pcap.go b/plugin/debug/pcap.go new file mode 100644 index 0000000..493478a --- /dev/null +++ b/plugin/debug/pcap.go @@ -0,0 +1,72 @@ +package debug + +import ( + "bytes" + "fmt" + + "github.com/coredns/coredns/plugin/pkg/log" + + "github.com/miekg/dns" +) + +// Hexdump converts the dns message m to a hex dump Wireshark can import. +// See https://www.wireshark.org/docs/man-pages/text2pcap.html. +// This output looks like this: +// +// 00000 dc bd 01 00 00 01 00 00 00 00 00 01 07 65 78 61 +// 000010 6d 70 6c 65 05 6c 6f 63 61 6c 00 00 01 00 01 00 +// 000020 00 29 10 00 00 00 80 00 00 00 +// 00002a +// +// Hexdump will use log.Debug to write the dump to the log, each line +// is prefixed with 'debug: ' so the data can be easily extracted. +// +// msg will prefix the pcap dump. +func Hexdump(m *dns.Msg, v ...interface{}) { + if !log.D.Value() { + return + } + + buf, _ := m.Pack() + if len(buf) == 0 { + return + } + + out := "\n" + string(hexdump(buf)) + v = append(v, out) + log.Debug(v...) +} + +// Hexdumpf dumps a DNS message as Hexdump, but allows a format string. +func Hexdumpf(m *dns.Msg, format string, v ...interface{}) { + if !log.D.Value() { + return + } + + buf, _ := m.Pack() + if len(buf) == 0 { + return + } + + format += "\n%s" + v = append(v, hexdump(buf)) + log.Debugf(format, v...) 
+} + +func hexdump(data []byte) []byte { + b := new(bytes.Buffer) + + newline := "" + for i := 0; i < len(data); i++ { + if i%16 == 0 { + fmt.Fprintf(b, "%s%s%06x", newline, prefix, i) + newline = "\n" + } + fmt.Fprintf(b, " %02x", data[i]) + } + fmt.Fprintf(b, "\n%s%06x", prefix, len(data)) + + return b.Bytes() +} + +const prefix = "debug: " diff --git a/plugin/debug/pcap_test.go b/plugin/debug/pcap_test.go new file mode 100644 index 0000000..6b263c8 --- /dev/null +++ b/plugin/debug/pcap_test.go @@ -0,0 +1,73 @@ +package debug + +import ( + "bytes" + "fmt" + golog "log" + "strings" + "testing" + + "github.com/coredns/coredns/plugin/pkg/log" + + "github.com/miekg/dns" +) + +func msg() *dns.Msg { + m := new(dns.Msg) + m.SetQuestion("example.local.", dns.TypeA) + m.SetEdns0(4096, true) + m.Id = 10 + return m +} + +func TestNoDebug(t *testing.T) { + // Must come first, because set log.D.Set() which is impossible to undo. + var f bytes.Buffer + golog.SetOutput(&f) + + str := "Hi There!" + Hexdumpf(msg(), "%s %d", str, 10) + if len(f.Bytes()) != 0 { + t.Errorf("Expected no output, got %d bytes", len(f.Bytes())) + } +} + +func ExampleHexdump() { + buf, _ := msg().Pack() + h := hexdump(buf) + fmt.Println(string(h)) + + // Output: + // debug: 000000 00 0a 01 00 00 01 00 00 00 00 00 01 07 65 78 61 + // debug: 000010 6d 70 6c 65 05 6c 6f 63 61 6c 00 00 01 00 01 00 + // debug: 000020 00 29 10 00 00 00 80 00 00 00 + // debug: 00002a +} + +func TestHexdump(t *testing.T) { + var f bytes.Buffer + golog.SetOutput(&f) + log.D.Set() + + str := "Hi There!" + Hexdump(msg(), str) + logged := f.String() + + if !strings.Contains(logged, "[DEBUG] "+str) { + t.Errorf("The string %s, is not contained in the logged output: %s", str, logged) + } +} + +func TestHexdumpf(t *testing.T) { + var f bytes.Buffer + golog.SetOutput(&f) + log.D.Set() + + str := "Hi There!" + Hexdumpf(msg(), "%s %d", str, 10) + logged := f.String() + + if !strings.Contains(logged, "[DEBUG] "+fmt.Sprintf("%s %d", str, 10)) { + t.Errorf("The string %s %d, is not contained in the logged output: %s", str, 10, logged) + } +} diff --git a/plugin/deprecated/setup.go b/plugin/deprecated/setup.go new file mode 100644 index 0000000..64caa0c --- /dev/null +++ b/plugin/deprecated/setup.go @@ -0,0 +1,34 @@ +// Package deprecated is used when we deprecated plugin. In plugin.cfg just go from +// +// startup:github.com/coredns/caddy/startupshutdown +// +// To: +// +// startup:deprecated +// +// And things should work as expected. This means starting CoreDNS will fail with an error. We can only +// point to the release notes to details what next steps a user should take. I.e. there is no way to add this +// to the error generated. +package deprecated + +import ( + "errors" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/plugin" +) + +// removed has the names of the plugins that need to error on startup. +var removed = []string{""} + +func setup(c *caddy.Controller) error { + c.Next() + x := c.Val() + return plugin.Error(x, errors.New("this plugin has been deprecated")) +} + +func init() { + for _, plug := range removed { + plugin.Register(plug, setup) + } +} diff --git a/plugin/dns64/README.md b/plugin/dns64/README.md new file mode 100644 index 0000000..7975574 --- /dev/null +++ b/plugin/dns64/README.md @@ -0,0 +1,106 @@ +# dns64 + +## Name + +*dns64* - enables DNS64 IPv6 transition mechanism. 
+ +## Description + +The *dns64* plugin will when asked for a domain's AAAA records, but only finds A records, +synthesizes the AAAA records from the A records. + +The synthesis is *only* performed **if the query came in via IPv6**. + +This translation is for IPv6-only networks that have [NAT64](https://en.wikipedia.org/wiki/NAT64). + +## Syntax + +~~~ +dns64 [PREFIX] +~~~ + +* **PREFIX** defines a custom prefix instead of the default `64:ff9b::/96`. + +Or use this slightly longer form with more options: + +~~~ +dns64 [PREFIX] { + [translate_all] + prefix PREFIX + [allow_ipv4] +} +~~~ + +* `prefix` specifies any local IPv6 prefix to use, instead of the well known prefix (64:ff9b::/96) +* `translate_all` translates all queries, including responses that have AAAA results. +* `allow_ipv4` Allow translating queries if they come in over IPv4, default is IPv6 only translation. + +## Examples + +Translate with the default well known prefix. Applies to all queries (if they came in over IPv6). + +~~~ +. { + dns64 +} +~~~ + +Use a custom prefix. + +~~~ corefile +. { + dns64 64:1337::/96 +} +~~~ + +Or +~~~ corefile +. { + dns64 { + prefix 64:1337::/96 + } +} +~~~ + +Enable translation even if an existing AAAA record is present. + +~~~ corefile +. { + dns64 { + translate_all + } +} +~~~ + +Apply translation even to the requests which arrived over IPv4 network. Warning, the `allow_ipv4` feature will apply +translations to requests coming from dual-stack clients. This means that a request for a client that sends an `AAAA` +that would normal result in an `NXDOMAIN` would get a translated result. +This may cause unwanted IPv6 dns64 traffic when a dualstack client would normally use the result of an `A` record request. + +~~~ corefile +. { + dns64 { + allow_ipv4 + } +} +~~~ + +## Metrics + +If monitoring is enabled (via the _prometheus_ plugin) then the following metrics are exported: + +- `coredns_dns64_requests_translated_total{server}` - counter of DNS requests translated + +The `server` label is explained in the _prometheus_ plugin documentation. + +## Bugs + +Not all features required by DNS64 are implemented, only basic AAAA synthesis. + +* Support "mapping of separate IPv4 ranges to separate IPv6 prefixes" +* Resolve PTR records +* Make resolver DNSSEC aware. See: [RFC 6147 Section 3](https://tools.ietf.org/html/rfc6147#section-3) + +## See Also + +See [RFC 6147](https://tools.ietf.org/html/rfc6147) for more information on the DNS64 mechanism. diff --git a/plugin/dns64/dns64.go b/plugin/dns64/dns64.go new file mode 100644 index 0000000..9f426eb --- /dev/null +++ b/plugin/dns64/dns64.go @@ -0,0 +1,208 @@ +// Package dns64 implements a plugin that performs DNS64. +// +// See: RFC 6147 (https://tools.ietf.org/html/rfc6147) +package dns64 + +import ( + "context" + "errors" + "net" + "time" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/metrics" + "github.com/coredns/coredns/plugin/pkg/nonwriter" + "github.com/coredns/coredns/plugin/pkg/response" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// UpstreamInt wraps the Upstream API for dependency injection during testing +type UpstreamInt interface { + Lookup(ctx context.Context, state request.Request, name string, typ uint16) (*dns.Msg, error) +} + +// DNS64 performs DNS64. +type DNS64 struct { + Next plugin.Handler + Prefix *net.IPNet + TranslateAll bool // Not comply with 5.1.1 + AllowIPv4 bool + Upstream UpstreamInt +} + +// ServeDNS implements the plugin.Handler interface. 
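+// The flow: decide whether the query is eligible (requestShouldIntercept), let the next
+// plugin answer into a nonwriter, and only when that answer carries no usable AAAA records
+// (responseShouldDNS64) perform the upstream A lookup and synthesize AAAA answers.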
+func (d *DNS64) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + // Don't proxy if we don't need to. + if !d.requestShouldIntercept(&request.Request{W: w, Req: r}) { + return d.Next.ServeDNS(ctx, w, r) + } + + // Pass the request to the next plugin in the chain, but intercept the response. + nw := nonwriter.New(w) + origRc, origErr := d.Next.ServeDNS(ctx, nw, r) + if nw.Msg == nil { // somehow we didn't get a response (or raw bytes were written) + return origRc, origErr + } + + // If the response doesn't need DNS64, short-circuit. + if !d.responseShouldDNS64(nw.Msg) { + w.WriteMsg(nw.Msg) + return origRc, origErr + } + + // otherwise do the actual DNS64 request and response synthesis + msg, err := d.DoDNS64(ctx, w, r, nw.Msg) + if err != nil { + // err means we weren't able to even issue the A request + // to CoreDNS upstream + return dns.RcodeServerFailure, err + } + + RequestsTranslatedCount.WithLabelValues(metrics.WithServer(ctx)).Inc() + w.WriteMsg(msg) + return msg.MsgHdr.Rcode, nil +} + +// Name implements the Handler interface. +func (d *DNS64) Name() string { return "dns64" } + +// requestShouldIntercept returns true if the request represents one that is eligible +// for DNS64 rewriting: +// 1. The request came in over IPv6 or the 'allow_ipv4' option is set +// 2. The request is of type AAAA +// 3. The request is of class INET +func (d *DNS64) requestShouldIntercept(req *request.Request) bool { + // Make sure that request came in over IPv4 unless AllowIPv4 option is enabled. + // Translating requests without taking into consideration client (source) IP might be problematic in dual-stack networks. + if !d.AllowIPv4 && req.Family() == 1 { + return false + } + + // Do not modify if question is not AAAA or not of class IN. See RFC 6147 5.1 + return req.QType() == dns.TypeAAAA && req.QClass() == dns.ClassINET +} + +// responseShouldDNS64 returns true if the response indicates we should attempt +// DNS64 rewriting: +// 1. The response has no valid (RFC 5.1.4) AAAA records (RFC 5.1.1) +// 2. The response code (RCODE) is not 3 (Name Error) (RFC 5.1.2) +// +// Note that requestShouldIntercept must also have been true, so the request +// is known to be of type AAAA. +func (d *DNS64) responseShouldDNS64(origResponse *dns.Msg) bool { + ty, _ := response.Typify(origResponse, time.Now().UTC()) + + // Handle NameError normally. See RFC 6147 5.1.2 + // All other error types are "equivalent" to empty response + if ty == response.NameError { + return false + } + + // If we've configured to always translate, well, then always translate. + if d.TranslateAll { + return true + } + + // if response includes AAAA record, no need to rewrite + for _, rr := range origResponse.Answer { + if rr.Header().Rrtype == dns.TypeAAAA { + return false + } + } + return true +} + +// DoDNS64 takes an (empty) response to an AAAA question, issues the A request, +// and synthesizes the answer. Returns the response message, or error on internal failure. 
+func (d *DNS64) DoDNS64(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, origResponse *dns.Msg) (*dns.Msg, error) { + req := request.Request{W: w, Req: r} // req is unused + resp, err := d.Upstream.Lookup(ctx, req, req.Name(), dns.TypeA) + if err != nil { + return nil, err + } + out := d.Synthesize(r, origResponse, resp) + return out, nil +} + +// Synthesize merges the AAAA response and the records from the A response +func (d *DNS64) Synthesize(origReq, origResponse, resp *dns.Msg) *dns.Msg { + ret := dns.Msg{} + ret.SetReply(origReq) + + // persist truncated state of AAAA response + ret.Truncated = resp.Truncated + + // 5.3.2: DNS64 MUST pass the additional section unchanged + ret.Extra = resp.Extra + ret.Ns = resp.Ns + + // 5.1.7: The TTL is the minimum of the A RR and the SOA RR. If SOA is + // unknown, then the TTL is the minimum of A TTL and 600 + SOATtl := uint32(600) // Default NS record TTL + for _, ns := range origResponse.Ns { + if ns.Header().Rrtype == dns.TypeSOA { + SOATtl = ns.Header().Ttl + } + } + + ret.Answer = make([]dns.RR, 0, len(resp.Answer)) + // convert A records to AAAA records + for _, rr := range resp.Answer { + header := rr.Header() + // 5.3.3: All other RR's MUST be returned unchanged + if header.Rrtype != dns.TypeA { + ret.Answer = append(ret.Answer, rr) + continue + } + + aaaa, _ := to6(d.Prefix, rr.(*dns.A).A) + + // ttl is min of SOA TTL and A TTL + ttl := SOATtl + if rr.Header().Ttl < ttl { + ttl = rr.Header().Ttl + } + + // Replace A answer with a DNS64 AAAA answer + ret.Answer = append(ret.Answer, &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: header.Name, + Rrtype: dns.TypeAAAA, + Class: header.Class, + Ttl: ttl, + }, + AAAA: aaaa, + }) + } + return &ret +} + +// to6 takes a prefix and IPv4 address and returns an IPv6 address according to RFC 6052. 
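+// Per RFC 6052 the prefix may be 32, 40, 48, 56, 64 or 96 bits long and byte 8 of the
+// resulting address (bits 64-71, the "u" octet) must remain zero, which is why the copy
+// below steps over index 8 when the prefix is shorter than 96 bits.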
+func to6(prefix *net.IPNet, addr net.IP) (net.IP, error) { + addr = addr.To4() + if addr == nil { + return nil, errors.New("not a valid IPv4 address") + } + + n, _ := prefix.Mask.Size() + // Assumes prefix has been validated during setup + v6 := make([]byte, 16) + i, j := 0, 0 + + for ; i < n/8; i++ { + v6[i] = prefix.IP[i] + } + for ; i < 8; i, j = i+1, j+1 { + v6[i] = addr[j] + } + if i == 8 { + i++ + } + for ; j < 4; i, j = i+1, j+1 { + v6[i] = addr[j] + } + + return v6, nil +} diff --git a/plugin/dns64/dns64_test.go b/plugin/dns64/dns64_test.go new file mode 100644 index 0000000..a294721 --- /dev/null +++ b/plugin/dns64/dns64_test.go @@ -0,0 +1,556 @@ +package dns64 + +import ( + "context" + "fmt" + "net" + "reflect" + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +func To6(prefix, address string) (net.IP, error) { + _, pref, _ := net.ParseCIDR(prefix) + addr := net.ParseIP(address) + + return to6(pref, addr) +} + +func TestRequestShouldIntercept(t *testing.T) { + tests := []struct { + name string + allowIpv4 bool + remoteIP string + msg *dns.Msg + want bool + }{ + { + name: "should intercept request from IPv6 network - AAAA - IN", + allowIpv4: true, + remoteIP: "::1", + msg: new(dns.Msg).SetQuestion("example.com", dns.TypeAAAA), + want: true, + }, + { + name: "should intercept request from IPv4 network - AAAA - IN", + allowIpv4: true, + remoteIP: "127.0.0.1", + msg: new(dns.Msg).SetQuestion("example.com", dns.TypeAAAA), + want: true, + }, + { + name: "should not intercept request from IPv4 network - AAAA - IN", + allowIpv4: false, + remoteIP: "127.0.0.1", + msg: new(dns.Msg).SetQuestion("example.com", dns.TypeAAAA), + want: false, + }, + { + name: "should not intercept request from IPv6 network - A - IN", + allowIpv4: false, + remoteIP: "::1", + msg: new(dns.Msg).SetQuestion("example.com", dns.TypeA), + want: false, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + h := DNS64{AllowIPv4: tc.allowIpv4} + rec := dnstest.NewRecorder(&test.ResponseWriter{RemoteIP: tc.remoteIP}) + r := request.Request{W: rec, Req: tc.msg} + + actual := h.requestShouldIntercept(&r) + + if actual != tc.want { + t.Fatalf("Expected %v, but got %v", tc.want, actual) + } + }) + } +} + +func TestTo6(t *testing.T) { + v6, err := To6("64:ff9b::/96", "64.64.64.64") + if err != nil { + t.Error(err) + } + if v6.String() != "64:ff9b::4040:4040" { + t.Errorf("%d", v6) + } + + v6, err = To6("64:ff9b::/64", "64.64.64.64") + if err != nil { + t.Error(err) + } + if v6.String() != "64:ff9b::40:4040:4000:0" { + t.Errorf("%d", v6) + } + + v6, err = To6("64:ff9b::/56", "64.64.64.64") + if err != nil { + t.Error(err) + } + if v6.String() != "64:ff9b:0:40:40:4040::" { + t.Errorf("%d", v6) + } + + v6, err = To6("64::/32", "64.64.64.64") + if err != nil { + t.Error(err) + } + if v6.String() != "64:0:4040:4040::" { + t.Errorf("%d", v6) + } +} + +func TestResponseShould(t *testing.T) { + var tests = []struct { + resp dns.Msg + translateAll bool + expected bool + }{ + // If there's an AAAA record, then no + { + resp: dns.Msg{ + MsgHdr: dns.MsgHdr{ + Rcode: dns.RcodeSuccess, + }, + Answer: []dns.RR{ + test.AAAA("example.com. IN AAAA ::1"), + }, + }, + expected: false, + }, + // If there's no AAAA, then true + { + resp: dns.Msg{ + MsgHdr: dns.MsgHdr{ + Rcode: dns.RcodeSuccess, + }, + Ns: []dns.RR{ + test.SOA("example.com. 
IN SOA foo bar 1 1 1 1 1"), + }, + }, + expected: true, + }, + // Failure, except NameError, should be true + { + resp: dns.Msg{ + MsgHdr: dns.MsgHdr{ + Rcode: dns.RcodeNotImplemented, + }, + Ns: []dns.RR{ + test.SOA("example.com. IN SOA foo bar 1 1 1 1 1"), + }, + }, + expected: true, + }, + // NameError should be false + { + resp: dns.Msg{ + MsgHdr: dns.MsgHdr{ + Rcode: dns.RcodeNameError, + }, + Ns: []dns.RR{ + test.SOA("example.com. IN SOA foo bar 1 1 1 1 1"), + }, + }, + expected: false, + }, + // If there's an AAAA record, but translate_all is configured, then yes + { + resp: dns.Msg{ + MsgHdr: dns.MsgHdr{ + Rcode: dns.RcodeSuccess, + }, + Answer: []dns.RR{ + test.AAAA("example.com. IN AAAA ::1"), + }, + }, + translateAll: true, + expected: true, + }, + } + + d := DNS64{} + + for idx, tc := range tests { + t.Run(fmt.Sprintf("%d", idx), func(t *testing.T) { + d.TranslateAll = tc.translateAll + actual := d.responseShouldDNS64(&tc.resp) + if actual != tc.expected { + t.Fatalf("Expected %v got %v", tc.expected, actual) + } + }) + } +} + +func TestDNS64(t *testing.T) { + var cases = []struct { + // a brief summary of the test case + name string + + // the request + req *dns.Msg + + // the initial response from the "downstream" server + initResp *dns.Msg + + // A response to provide + aResp *dns.Msg + + // the expected ultimate result + resp *dns.Msg + }{ + { + // no AAAA record, yes A record. Do DNS64 + name: "standard flow", + req: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: 42, + RecursionDesired: true, + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{{Name: "example.com.", Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}}, + }, + initResp: &dns.Msg{ //success, no answers + MsgHdr: dns.MsgHdr{ + Id: 42, + Opcode: dns.OpcodeQuery, + RecursionDesired: true, + Rcode: dns.RcodeSuccess, + Response: true, + }, + Question: []dns.Question{{Name: "example.com.", Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}}, + Ns: []dns.RR{test.SOA("example.com. 70 IN SOA foo bar 1 1 1 1 1")}, + }, + aResp: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: 43, + Opcode: dns.OpcodeQuery, + RecursionDesired: true, + Rcode: dns.RcodeSuccess, + Response: true, + }, + Question: []dns.Question{{Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}}, + Answer: []dns.RR{ + test.A("example.com. 60 IN A 192.0.2.42"), + test.A("example.com. 5000 IN A 192.0.2.43"), + }, + }, + + resp: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: 42, + Opcode: dns.OpcodeQuery, + RecursionDesired: true, + Rcode: dns.RcodeSuccess, + Response: true, + }, + Question: []dns.Question{{Name: "example.com.", Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}}, + Answer: []dns.RR{ + test.AAAA("example.com. 60 IN AAAA 64:ff9b::192.0.2.42"), + // override RR ttl to SOA ttl, since it's lower + test.AAAA("example.com. 70 IN AAAA 64:ff9b::192.0.2.43"), + }, + }, + }, + { + // name exists, but has neither A nor AAAA record + name: "a empty", + req: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: 42, + RecursionDesired: true, + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{{Name: "example.com.", Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}}, + }, + initResp: &dns.Msg{ //success, no answers + MsgHdr: dns.MsgHdr{ + Id: 42, + Opcode: dns.OpcodeQuery, + RecursionDesired: true, + Rcode: dns.RcodeSuccess, + Response: true, + }, + Question: []dns.Question{{Name: "example.com.", Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}}, + Ns: []dns.RR{test.SOA("example.com. 
3600 IN SOA foo bar 1 7200 900 1209600 86400")}, + }, + aResp: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: 43, + Opcode: dns.OpcodeQuery, + RecursionDesired: true, + Rcode: dns.RcodeSuccess, + Response: true, + }, + Question: []dns.Question{{Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}}, + Ns: []dns.RR{test.SOA("example.com. 3600 IN SOA foo bar 1 7200 900 1209600 86400")}, + }, + + resp: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: 42, + Opcode: dns.OpcodeQuery, + RecursionDesired: true, + Rcode: dns.RcodeSuccess, + Response: true, + }, + Question: []dns.Question{{Name: "example.com.", Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}}, + Ns: []dns.RR{test.SOA("example.com. 3600 IN SOA foo bar 1 7200 900 1209600 86400")}, + Answer: []dns.RR{}, // just to make comparison happy + }, + }, + { + // Query error other than NameError + name: "non-nxdomain error", + req: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: 42, + RecursionDesired: true, + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{{Name: "example.com.", Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}}, + }, + initResp: &dns.Msg{ // failure + MsgHdr: dns.MsgHdr{ + Id: 42, + Opcode: dns.OpcodeQuery, + RecursionDesired: true, + Rcode: dns.RcodeRefused, + Response: true, + }, + Question: []dns.Question{{Name: "example.com.", Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}}, + }, + aResp: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: 43, + Opcode: dns.OpcodeQuery, + RecursionDesired: true, + Rcode: dns.RcodeSuccess, + Response: true, + }, + Question: []dns.Question{{Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}}, + Answer: []dns.RR{ + test.A("example.com. 60 IN A 192.0.2.42"), + test.A("example.com. 5000 IN A 192.0.2.43"), + }, + }, + + resp: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: 42, + Opcode: dns.OpcodeQuery, + RecursionDesired: true, + Rcode: dns.RcodeSuccess, + Response: true, + }, + Question: []dns.Question{{Name: "example.com.", Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}}, + Answer: []dns.RR{ + test.AAAA("example.com. 60 IN AAAA 64:ff9b::192.0.2.42"), + test.AAAA("example.com. 600 IN AAAA 64:ff9b::192.0.2.43"), + }, + }, + }, + { + // nxdomain (NameError): don't even try an A request. + name: "nxdomain", + req: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: 42, + RecursionDesired: true, + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{{Name: "example.com.", Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}}, + }, + initResp: &dns.Msg{ // failure + MsgHdr: dns.MsgHdr{ + Id: 42, + Opcode: dns.OpcodeQuery, + RecursionDesired: true, + Rcode: dns.RcodeNameError, + Response: true, + }, + Question: []dns.Question{{Name: "example.com.", Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}}, + Ns: []dns.RR{test.SOA("example.com. 3600 IN SOA foo bar 1 7200 900 1209600 86400")}, + }, + resp: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: 42, + Opcode: dns.OpcodeQuery, + RecursionDesired: true, + Rcode: dns.RcodeNameError, + Response: true, + }, + Question: []dns.Question{{Name: "example.com.", Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}}, + Ns: []dns.RR{test.SOA("example.com. 
3600 IN SOA foo bar 1 7200 900 1209600 86400")}, + }, + }, + { + // AAAA record exists + name: "AAAA record", + req: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: 42, + RecursionDesired: true, + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{{Name: "example.com.", Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}}, + }, + + initResp: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: 42, + Opcode: dns.OpcodeQuery, + RecursionDesired: true, + Rcode: dns.RcodeSuccess, + Response: true, + }, + Question: []dns.Question{{Name: "example.com.", Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}}, + Answer: []dns.RR{ + test.AAAA("example.com. 60 IN AAAA ::1"), + test.AAAA("example.com. 5000 IN AAAA ::2"), + }, + }, + + resp: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: 42, + Opcode: dns.OpcodeQuery, + RecursionDesired: true, + Rcode: dns.RcodeSuccess, + Response: true, + }, + Question: []dns.Question{{Name: "example.com.", Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}}, + Answer: []dns.RR{ + test.AAAA("example.com. 60 IN AAAA ::1"), + test.AAAA("example.com. 5000 IN AAAA ::2"), + }, + }, + }, + { + // no AAAA records, A record response truncated. + name: "truncated A response", + req: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: 42, + RecursionDesired: true, + Opcode: dns.OpcodeQuery, + }, + Question: []dns.Question{{Name: "example.com.", Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}}, + }, + initResp: &dns.Msg{ //success, no answers + MsgHdr: dns.MsgHdr{ + Id: 42, + Opcode: dns.OpcodeQuery, + RecursionDesired: true, + Rcode: dns.RcodeSuccess, + Response: true, + }, + Question: []dns.Question{{Name: "example.com.", Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}}, + Ns: []dns.RR{test.SOA("example.com. 70 IN SOA foo bar 1 1 1 1 1")}, + }, + aResp: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: 43, + Opcode: dns.OpcodeQuery, + RecursionDesired: true, + Truncated: true, + Rcode: dns.RcodeSuccess, + Response: true, + }, + Question: []dns.Question{{Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}}, + Answer: []dns.RR{ + test.A("example.com. 60 IN A 192.0.2.42"), + test.A("example.com. 5000 IN A 192.0.2.43"), + }, + }, + + resp: &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: 42, + Opcode: dns.OpcodeQuery, + RecursionDesired: true, + Truncated: true, + Rcode: dns.RcodeSuccess, + Response: true, + }, + Question: []dns.Question{{Name: "example.com.", Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}}, + Answer: []dns.RR{ + test.AAAA("example.com. 60 IN AAAA 64:ff9b::192.0.2.42"), + // override RR ttl to SOA ttl, since it's lower + test.AAAA("example.com. 
70 IN AAAA 64:ff9b::192.0.2.43"), + }, + }, + }, + } + + _, pfx, _ := net.ParseCIDR("64:ff9b::/96") + + for idx, tc := range cases { + t.Run(fmt.Sprintf("%d_%s", idx, tc.name), func(t *testing.T) { + d := DNS64{ + Next: &fakeHandler{t, tc.initResp}, + Prefix: pfx, + Upstream: &fakeUpstream{t, tc.req.Question[0].Name, tc.aResp}, + } + + rec := dnstest.NewRecorder(&test.ResponseWriter{RemoteIP: "::1"}) + rc, err := d.ServeDNS(context.Background(), rec, tc.req) + if err != nil { + t.Fatal(err) + } + actual := rec.Msg + if actual.Rcode != rc { + t.Fatalf("ServeDNS should return real result code %q != %q", actual.Rcode, rc) + } + + if !reflect.DeepEqual(actual, tc.resp) { + t.Fatalf("Final answer should match expected %q != %q", actual, tc.resp) + } + }) + } +} + +type fakeHandler struct { + t *testing.T + reply *dns.Msg +} + +func (fh *fakeHandler) ServeDNS(_ context.Context, w dns.ResponseWriter, _ *dns.Msg) (int, error) { + if fh.reply == nil { + panic("fakeHandler ServeDNS with nil reply") + } + w.WriteMsg(fh.reply) + + return fh.reply.Rcode, nil +} +func (fh *fakeHandler) Name() string { + return "fake" +} + +type fakeUpstream struct { + t *testing.T + qname string + resp *dns.Msg +} + +func (fu *fakeUpstream) Lookup(_ context.Context, _ request.Request, name string, typ uint16) (*dns.Msg, error) { + if fu.qname == "" { + fu.t.Fatalf("Unexpected A lookup for %s", name) + } + if name != fu.qname { + fu.t.Fatalf("Wrong A lookup for %s, expected %s", name, fu.qname) + } + + if typ != dns.TypeA { + fu.t.Fatalf("Wrong lookup type %d, expected %d", typ, dns.TypeA) + } + + return fu.resp, nil +} diff --git a/plugin/dns64/metrics.go b/plugin/dns64/metrics.go new file mode 100644 index 0000000..9552316 --- /dev/null +++ b/plugin/dns64/metrics.go @@ -0,0 +1,18 @@ +package dns64 + +import ( + "github.com/coredns/coredns/plugin" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +var ( + // RequestsTranslatedCount is the number of DNS requests translated by dns64. 
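+	// With the "coredns" namespace and the "dns64" subsystem this is exported as
+	// coredns_dns64_requests_translated_total{server}, matching the README.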
+ RequestsTranslatedCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: pluginName, + Name: "requests_translated_total", + Help: "Counter of DNS requests translated by dns64.", + }, []string{"server"}) +) diff --git a/plugin/dns64/setup.go b/plugin/dns64/setup.go new file mode 100644 index 0000000..5e06187 --- /dev/null +++ b/plugin/dns64/setup.go @@ -0,0 +1,92 @@ +package dns64 + +import ( + "net" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/upstream" +) + +const pluginName = "dns64" + +func init() { plugin.Register(pluginName, setup) } + +func setup(c *caddy.Controller) error { + dns64, err := dns64Parse(c) + if err != nil { + return plugin.Error(pluginName, err) + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + dns64.Next = next + return dns64 + }) + + return nil +} + +func dns64Parse(c *caddy.Controller) (*DNS64, error) { + _, defaultPref, _ := net.ParseCIDR("64:ff9b::/96") + dns64 := &DNS64{ + Upstream: upstream.New(), + Prefix: defaultPref, + } + + for c.Next() { + args := c.RemainingArgs() + if len(args) == 1 { + pref, err := parsePrefix(c, args[0]) + + if err != nil { + return nil, err + } + dns64.Prefix = pref + continue + } + if len(args) > 0 { + return nil, c.ArgErr() + } + + for c.NextBlock() { + switch c.Val() { + case "prefix": + if !c.NextArg() { + return nil, c.ArgErr() + } + pref, err := parsePrefix(c, c.Val()) + + if err != nil { + return nil, err + } + dns64.Prefix = pref + case "translate_all": + dns64.TranslateAll = true + case "allow_ipv4": + dns64.AllowIPv4 = true + default: + return nil, c.Errf("unknown property '%s'", c.Val()) + } + } + } + return dns64, nil +} + +func parsePrefix(c *caddy.Controller, addr string) (*net.IPNet, error) { + _, pref, err := net.ParseCIDR(addr) + if err != nil { + return nil, err + } + + // Test for valid prefix + n, total := pref.Mask.Size() + if total != 128 { + return nil, c.Errf("invalid netmask %d IPv6 address: %q", total, pref) + } + if n%8 != 0 || n < 32 || n > 96 { + return nil, c.Errf("invalid prefix length %q", pref) + } + + return pref, nil +} diff --git a/plugin/dns64/setup_test.go b/plugin/dns64/setup_test.go new file mode 100644 index 0000000..e7d13f4 --- /dev/null +++ b/plugin/dns64/setup_test.go @@ -0,0 +1,153 @@ +package dns64 + +import ( + "testing" + + "github.com/coredns/caddy" +) + +func TestSetupDns64(t *testing.T) { + tests := []struct { + inputUpstreams string + shouldErr bool + wantPrefix string + wantAllowIpv4 bool + }{ + { + `dns64`, + false, + "64:ff9b::/96", + false, + }, + { + `dns64 64:dead::/96`, + false, + "64:dead::/96", + false, + }, + { + `dns64 { + translate_all + }`, + false, + "64:ff9b::/96", + false, + }, + { + `dns64`, + false, + "64:ff9b::/96", + false, + }, + { + `dns64 { + prefix 64:ff9b::/96 + }`, + false, + "64:ff9b::/96", + false, + }, + { + `dns64 { + prefix 64:ff9b::/32 + }`, + false, + "64:ff9b::/32", + false, + }, + { + `dns64 { + prefix 64:ff9b::/52 + }`, + true, + "64:ff9b::/52", + false, + }, + { + `dns64 { + prefix 64:ff9b::/104 + }`, + true, + "64:ff9b::/104", + false, + }, + { + `dns64 { + prefix 8.8.8.8/24 + }`, + true, + "8.8.9.9/24", + false, + }, + { + `dns64 { + prefix 64:ff9b::/96 + }`, + false, + "64:ff9b::/96", + false, + }, + { + `dns64 { + prefix 2002:ac12:b083::/96 + }`, + false, + "2002:ac12:b083::/96", + false, + }, + { + `dns64 { + prefix 2002:c0a8:a88a::/48 + }`, + false, + 
"2002:c0a8:a88a::/48", + false, + }, + { + `dns64 foobar { + prefix 64:ff9b::/96 + }`, + true, + "64:ff9b::/96", + false, + }, + { + `dns64 foobar`, + true, + "64:ff9b::/96", + false, + }, + { + `dns64 { + foobar + }`, + true, + "64:ff9b::/96", + false, + }, + { + `dns64 { + allow_ipv4 + }`, + false, + "64:ff9b::/96", + true, + }, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.inputUpstreams) + dns64, err := dns64Parse(c) + if (err != nil) != test.shouldErr { + t.Errorf("Test %d expected %v error, got %v for %s", i+1, test.shouldErr, err, test.inputUpstreams) + } + if err == nil { + if dns64.Prefix.String() != test.wantPrefix { + t.Errorf("Test %d expected prefix %s, got %v", i+1, test.wantPrefix, dns64.Prefix.String()) + } + if dns64.AllowIPv4 != test.wantAllowIpv4 { + t.Errorf("Test %d expected prefix %v, got %v", i+1, test.wantAllowIpv4, dns64.AllowIPv4) + } + } + } +} diff --git a/plugin/dnssec/README.md b/plugin/dnssec/README.md new file mode 100644 index 0000000..00766a1 --- /dev/null +++ b/plugin/dnssec/README.md @@ -0,0 +1,87 @@ +# dnssec + +## Name + +*dnssec* - enables on-the-fly DNSSEC signing of served data. + +## Description + +With *dnssec*, any reply that doesn't (or can't) do DNSSEC will get signed on the fly. Authenticated +denial of existence is implemented with NSEC black lies. Using ECDSA as an algorithm is preferred as +this leads to smaller signatures (compared to RSA). NSEC3 is *not* supported. + +This plugin can only be used once per Server Block. + +## Syntax + +~~~ +dnssec [ZONES... ] { + key file KEY... + cache_capacity CAPACITY +} +~~~ + +The signing behavior depends on the keys specified. If multiple keys are specified of which there is +at least one key with the SEP bit set and at least one key with the SEP bit unset, signing will happen +in split ZSK/KSK mode. DNSKEY records will be signed with all keys that have the SEP bit set. All other +records will be signed with all keys that do not have the SEP bit set. + +In any other case, each specified key will be treated as a CSK (common signing key), forgoing the +ZSK/KSK split. All signing operations are done online. +Authenticated denial of existence is implemented with NSEC black lies. Using ECDSA as an algorithm +is preferred as this leads to smaller signatures (compared to RSA). NSEC3 is *not* supported. + +As the *dnssec* plugin can't see the original TTL of the RRSets it signs, it will always use 3600s +as the value. + +If multiple *dnssec* plugins are specified in the same zone, the last one specified will be +used. + +* **ZONES** zones that should be signed. If empty, the zones from the configuration block + are used. + +* `key file` indicates that **KEY** file(s) should be read from disk. When multiple keys are specified, RRsets + will be signed with all keys. Generating a key can be done with `dnssec-keygen`: `dnssec-keygen -a + ECDSAP256SHA256 <zonename>`. A key created for zone *A* can be safely used for zone *B*. The name of the + key file can be specified in one of the following formats + + * basename of the generated key `Kexample.org+013+45330` + * generated public key `Kexample.org+013+45330.key` + * generated private key `Kexample.org+013+45330.private` + +* `cache_capacity` indicates the capacity of the cache. The dnssec plugin uses a cache to store + RRSIGs. The default for **CAPACITY** is 10000. 
+ +## Metrics + +If monitoring is enabled (via the *prometheus* plugin) then the following metrics are exported: + +* `coredns_dnssec_cache_entries{server, type}` - total elements in the cache, type is "signature". +* `coredns_dnssec_cache_hits_total{server}` - Counter of cache hits. +* `coredns_dnssec_cache_misses_total{server}` - Counter of cache misses. + +The label `server` indicated the server handling the request, see the *metrics* plugin for details. + +## Examples + +Sign responses for `example.org` with the key "Kexample.org.+013+45330.key". + +~~~ corefile +example.org { + dnssec { + key file Kexample.org.+013+45330 + } + whoami +} +~~~ + +Sign responses for a kubernetes zone with the key "Kcluster.local+013+45129.key". + +~~~ +cluster.local { + kubernetes + dnssec { + key file Kcluster.local+013+45129 + } +} +~~~ diff --git a/plugin/dnssec/black_lies.go b/plugin/dnssec/black_lies.go new file mode 100644 index 0000000..d01fa7c --- /dev/null +++ b/plugin/dnssec/black_lies.go @@ -0,0 +1,79 @@ +package dnssec + +import ( + "strings" + + "github.com/coredns/coredns/plugin/pkg/response" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// nsec returns an NSEC useful for NXDOMAIN responses. +// See https://tools.ietf.org/html/draft-valsorda-dnsop-black-lies-00 +// For example, a request for the non-existing name a.example.com would +// cause the following NSEC record to be generated: +// +// a.example.com. 3600 IN NSEC \000.a.example.com. ( RRSIG NSEC ... ) +// +// This inturn makes every NXDOMAIN answer a NODATA one, don't forget to flip +// the header rcode to NOERROR. +func (d Dnssec) nsec(state request.Request, mt response.Type, ttl, incep, expir uint32, server string) ([]dns.RR, error) { + nsec := &dns.NSEC{} + nsec.Hdr = dns.RR_Header{Name: state.QName(), Ttl: ttl, Class: dns.ClassINET, Rrtype: dns.TypeNSEC} + nsec.NextDomain = "\\000." + state.QName() + if state.QName() == "." { + nsec.NextDomain = "\\000." // If You want to play as root server + } + if state.Name() == state.Zone { + nsec.TypeBitMap = filter18(state.QType(), apexBitmap, mt) + } else if mt == response.Delegation || state.QType() == dns.TypeDS { + nsec.TypeBitMap = delegationBitmap[:] + if mt == response.Delegation { + labels := dns.SplitDomainName(state.QName()) + labels[0] += "\\000" + nsec.NextDomain = strings.Join(labels, ".") + "." + } + } else { + nsec.TypeBitMap = filter14(state.QType(), zoneBitmap, mt) + } + + sigs, err := d.sign([]dns.RR{nsec}, state.Zone, ttl, incep, expir, server) + if err != nil { + return nil, err + } + + return append(sigs, nsec), nil +} + +// The NSEC bit maps we return. +var ( + delegationBitmap = [...]uint16{dns.TypeA, dns.TypeNS, dns.TypeHINFO, dns.TypeTXT, dns.TypeAAAA, dns.TypeLOC, dns.TypeSRV, dns.TypeCERT, dns.TypeSSHFP, dns.TypeRRSIG, dns.TypeNSEC, dns.TypeTLSA, dns.TypeHIP, dns.TypeOPENPGPKEY, dns.TypeSPF} + zoneBitmap = [...]uint16{dns.TypeA, dns.TypeHINFO, dns.TypeTXT, dns.TypeAAAA, dns.TypeLOC, dns.TypeSRV, dns.TypeCERT, dns.TypeSSHFP, dns.TypeRRSIG, dns.TypeNSEC, dns.TypeTLSA, dns.TypeHIP, dns.TypeOPENPGPKEY, dns.TypeSPF} + apexBitmap = [...]uint16{dns.TypeA, dns.TypeNS, dns.TypeSOA, dns.TypeHINFO, dns.TypeMX, dns.TypeTXT, dns.TypeAAAA, dns.TypeLOC, dns.TypeSRV, dns.TypeCERT, dns.TypeSSHFP, dns.TypeRRSIG, dns.TypeNSEC, dns.TypeDNSKEY, dns.TypeTLSA, dns.TypeHIP, dns.TypeOPENPGPKEY, dns.TypeSPF} +) + +// filter14 filters out t from bitmap (if it exists). If mt is not an NODATA response, just return the entire bitmap. 
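+// Removing the queried type from the bitmap is what lets the generated NSEC deny that type at the
+// name, turning the answer into a provable black-lies NODATA response.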
+func filter14(t uint16, bitmap [14]uint16, mt response.Type) []uint16 { + if mt != response.NoData && mt != response.NameError || t == dns.TypeNSEC { + return zoneBitmap[:] + } + for i := range bitmap { + if bitmap[i] == t { + return append(bitmap[:i], bitmap[i+1:]...) + } + } + return zoneBitmap[:] // make a slice +} + +func filter18(t uint16, bitmap [18]uint16, mt response.Type) []uint16 { + if mt != response.NoData && mt != response.NameError || t == dns.TypeNSEC { + return apexBitmap[:] + } + for i := range bitmap { + if bitmap[i] == t { + return append(bitmap[:i], bitmap[i+1:]...) + } + } + return apexBitmap[:] // make a slice +} diff --git a/plugin/dnssec/black_lies_bitmap_test.go b/plugin/dnssec/black_lies_bitmap_test.go new file mode 100644 index 0000000..4e9a10c --- /dev/null +++ b/plugin/dnssec/black_lies_bitmap_test.go @@ -0,0 +1,64 @@ +package dnssec + +import ( + "testing" + "time" + + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +const server = "dns//." + +func TestBlackLiesBitmapNoData(t *testing.T) { + d, rm1, rm2 := newDnssec(t, []string{"example.org."}) + defer rm1() + defer rm2() + + m := testTLSAMsg() + state := request.Request{Req: m, Zone: "example.org."} + m = d.Sign(state, time.Now().UTC(), server) + + var nsec *dns.NSEC + for _, r := range m.Ns { + if r.Header().Rrtype == dns.TypeNSEC { + nsec = r.(*dns.NSEC) + } + } + for _, b := range nsec.TypeBitMap { + if b == dns.TypeTLSA { + t.Errorf("Type TLSA should not be present in the type bitmap: %v", nsec.TypeBitMap) + } + } +} +func TestBlackLiesBitmapNameError(t *testing.T) { + d, rm1, rm2 := newDnssec(t, []string{"example.org."}) + defer rm1() + defer rm2() + + m := testTLSAMsg() + m.Rcode = dns.RcodeNameError // change to name error + state := request.Request{Req: m, Zone: "example.org."} + m = d.Sign(state, time.Now().UTC(), server) + + var nsec *dns.NSEC + for _, r := range m.Ns { + if r.Header().Rrtype == dns.TypeNSEC { + nsec = r.(*dns.NSEC) + } + } + for _, b := range nsec.TypeBitMap { + if b == dns.TypeTLSA { + t.Errorf("Type TLSA should not be present in the type bitmap: %v", nsec.TypeBitMap) + } + } +} + +func testTLSAMsg() *dns.Msg { + return &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeSuccess}, + Question: []dns.Question{{Name: "25._tcp.example.org.", Qclass: dns.ClassINET, Qtype: dns.TypeTLSA}}, + Ns: []dns.RR{test.SOA("example.org. 1800 IN SOA linode.example.org. miek.example.org. 1461471181 14400 3600 604800 14400")}, + } +} diff --git a/plugin/dnssec/black_lies_test.go b/plugin/dnssec/black_lies_test.go new file mode 100644 index 0000000..de381e5 --- /dev/null +++ b/plugin/dnssec/black_lies_test.go @@ -0,0 +1,258 @@ +package dnssec + +import ( + "testing" + "time" + + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +func TestZoneSigningBlackLies(t *testing.T) { + d, rm1, rm2 := newDnssec(t, []string{"miek.nl."}) + defer rm1() + defer rm2() + + m := testNxdomainMsg() + state := request.Request{Req: m, Zone: "miek.nl."} + m = d.Sign(state, time.Now().UTC(), server) + if !section(m.Ns, 2) { + t.Errorf("Authority section should have 2 sigs") + } + var nsec *dns.NSEC + for _, r := range m.Ns { + if r.Header().Rrtype == dns.TypeNSEC { + nsec = r.(*dns.NSEC) + } + } + if m.Rcode != dns.RcodeSuccess { + t.Errorf("Expected rcode %d, got %d", dns.RcodeSuccess, m.Rcode) + } + if nsec == nil { + t.Fatalf("Expected NSEC, got none") + } + if nsec.Hdr.Name != "ww.miek.nl." 
{ + t.Errorf("Expected %s, got %s", "ww.miek.nl.", nsec.Hdr.Name) + } + if nsec.NextDomain != "\\000.ww.miek.nl." { + t.Errorf("Expected %s, got %s", "\\000.ww.miek.nl.", nsec.NextDomain) + } +} + +func TestBlackLiesNoError(t *testing.T) { + d, rm1, rm2 := newDnssec(t, []string{"miek.nl."}) + defer rm1() + defer rm2() + + m := testSuccessMsg() + state := request.Request{Req: m, Zone: "miek.nl."} + m = d.Sign(state, time.Now().UTC(), server) + + if m.Rcode != dns.RcodeSuccess { + t.Errorf("Expected rcode %d, got %d", dns.RcodeSuccess, m.Rcode) + } + + if len(m.Answer) != 2 { + t.Errorf("Answer section should have 2 RRs") + } + sig, txt := false, false + for _, rr := range m.Answer { + if _, ok := rr.(*dns.RRSIG); ok { + sig = true + } + if _, ok := rr.(*dns.TXT); ok { + txt = true + } + } + if !sig || !txt { + t.Errorf("Expected RRSIG and TXT in answer section") + } +} + +func TestBlackLiesApexNsec(t *testing.T) { + d, rm1, rm2 := newDnssec(t, []string{"miek.nl."}) + defer rm1() + defer rm2() + + m := testNsecMsg() + m.SetQuestion("miek.nl.", dns.TypeNSEC) + state := request.Request{Req: m, Zone: "miek.nl."} + m = d.Sign(state, time.Now().UTC(), server) + if len(m.Ns) > 0 { + t.Error("Authority section should be empty") + } + if len(m.Answer) != 2 { + t.Errorf("Answer section should have 2 RRs") + } + sig, nsec := false, false + for _, rr := range m.Answer { + if _, ok := rr.(*dns.RRSIG); ok { + sig = true + } + if rnsec, ok := rr.(*dns.NSEC); ok { + nsec = true + var bitpresent uint + for _, typeBit := range rnsec.TypeBitMap { + switch typeBit { + case dns.TypeSOA: + bitpresent |= 4 + case dns.TypeNSEC: + bitpresent |= 1 + case dns.TypeRRSIG: + bitpresent |= 2 + } + } + if bitpresent != 7 { + t.Error("NSEC must have SOA, RRSIG and NSEC in its bitmap") + } + } + } + if !sig || !nsec { + t.Errorf("Expected RRSIG and NSEC in answer section") + } +} + +func TestBlackLiesNsec(t *testing.T) { + d, rm1, rm2 := newDnssec(t, []string{"miek.nl."}) + defer rm1() + defer rm2() + + m := testNsecMsg() + m.SetQuestion("www.miek.nl.", dns.TypeNSEC) + state := request.Request{Req: m, Zone: "miek.nl."} + m = d.Sign(state, time.Now().UTC(), server) + if len(m.Ns) > 0 { + t.Error("Authority section should be empty") + } + if len(m.Answer) != 2 { + t.Errorf("Answer section should have 2 RRs") + } + sig, nsec := false, false + for _, rr := range m.Answer { + if _, ok := rr.(*dns.RRSIG); ok { + sig = true + } + if rnsec, ok := rr.(*dns.NSEC); ok { + nsec = true + var bitpresent uint + for _, typeBit := range rnsec.TypeBitMap { + switch typeBit { + case dns.TypeNSEC: + bitpresent |= 1 + case dns.TypeRRSIG: + bitpresent |= 2 + } + } + if bitpresent != 3 { + t.Error("NSEC must have RRSIG and NSEC in its bitmap") + } + } + } + if !sig || !nsec { + t.Errorf("Expected RRSIG and NSEC in answer section") + } +} + +func TestBlackLiesApexDS(t *testing.T) { + d, rm1, rm2 := newDnssec(t, []string{"miek.nl."}) + defer rm1() + defer rm2() + + m := testApexDSMsg() + m.SetQuestion("miek.nl.", dns.TypeDS) + state := request.Request{Req: m, Zone: "miek.nl."} + m = d.Sign(state, time.Now().UTC(), server) + if !section(m.Ns, 2) { + t.Errorf("Authority section should have 2 sigs") + } + var nsec *dns.NSEC + for _, r := range m.Ns { + if r.Header().Rrtype == dns.TypeNSEC { + nsec = r.(*dns.NSEC) + } + } + if nsec == nil { + t.Error("Expected NSEC, got none") + } else if correctNsecForDS(nsec) { + t.Error("NSEC DS at the apex zone should cover all apex type.") + } +} + +func TestBlackLiesDS(t *testing.T) { + d, rm1, rm2 := 
newDnssec(t, []string{"miek.nl."}) + defer rm1() + defer rm2() + + m := testApexDSMsg() + m.SetQuestion("sub.miek.nl.", dns.TypeDS) + state := request.Request{Req: m, Zone: "miek.nl."} + m = d.Sign(state, time.Now().UTC(), server) + if !section(m.Ns, 2) { + t.Errorf("Authority section should have 2 sigs") + } + var nsec *dns.NSEC + for _, r := range m.Ns { + if r.Header().Rrtype == dns.TypeNSEC { + nsec = r.(*dns.NSEC) + } + } + if nsec == nil { + t.Error("Expected NSEC, got none") + } else if !correctNsecForDS(nsec) { + t.Error("NSEC DS should cover delegation type only.") + } +} + +func correctNsecForDS(nsec *dns.NSEC) bool { + var bitmask uint + /* Coherent TypeBitMap for NSEC of DS should contain at least: + * {TypeNS, TypeNSEC, TypeRRSIG} and no SOA. + * Any missing type will confuse resolver because + * it will prove that the dns query cannot be a delegation point, + * which will break trust resolution for unsigned delegated domain. + * No SOA is obvious for none apex query. + */ + for _, typeBitmask := range nsec.TypeBitMap { + switch typeBitmask { + case dns.TypeNS: + bitmask |= 1 + case dns.TypeNSEC: + bitmask |= 2 + case dns.TypeRRSIG: + bitmask |= 4 + case dns.TypeSOA: + return false + } + } + return bitmask == 7 +} + +func testNxdomainMsg() *dns.Msg { + return &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeNameError}, + Question: []dns.Question{{Name: "ww.miek.nl.", Qclass: dns.ClassINET, Qtype: dns.TypeTXT}}, + Ns: []dns.RR{test.SOA("miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 1461471181 14400 3600 604800 14400")}, + } +} + +func testSuccessMsg() *dns.Msg { + return &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeSuccess}, + Question: []dns.Question{{Name: "www.miek.nl.", Qclass: dns.ClassINET, Qtype: dns.TypeTXT}}, + Answer: []dns.RR{test.TXT(`www.miek.nl. 1800 IN TXT "response"`)}, + } +} + +func testNsecMsg() *dns.Msg { + return &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeNameError}, + Question: []dns.Question{{Name: "www.miek.nl.", Qclass: dns.ClassINET, Qtype: dns.TypeNSEC}}, + Ns: []dns.RR{test.SOA("miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 1461471181 14400 3600 604800 14400")}, + } +} + +func testApexDSMsg() *dns.Msg { + return &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeNameError}, + Question: []dns.Question{{Name: "miek.nl.", Qclass: dns.ClassINET, Qtype: dns.TypeDS}}, + Ns: []dns.RR{test.SOA("miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 1461471181 14400 3600 604800 14400")}, + } +} diff --git a/plugin/dnssec/cache.go b/plugin/dnssec/cache.go new file mode 100644 index 0000000..d80f5c1 --- /dev/null +++ b/plugin/dnssec/cache.go @@ -0,0 +1,48 @@ +package dnssec + +import ( + "hash/fnv" + "io" + "time" + + "github.com/coredns/coredns/plugin/pkg/cache" + + "github.com/miekg/dns" +) + +// hash serializes the RRset and returns a signature cache key. +func hash(rrs []dns.RR) uint64 { + h := fnv.New64() + // we need to hash the entire RRset to pick the correct sig, if the rrset + // changes for whatever reason we should resign. + // We could use wirefmt, or the string format, both create garbage when creating + // the hash key. And of course is a uint64 big enough? + for _, rr := range rrs { + io.WriteString(h, rr.String()) + } + return h.Sum64() +} + +func periodicClean(c *cache.Cache, stop <-chan struct{}) { + tick := time.NewTicker(8 * time.Hour) + defer tick.Stop() + for { + select { + case <-tick.C: + // we sign for 8 days, check if a signature in the cache reached 75% of that (i.e. 
6), if found delete + // the signature + is75 := time.Now().UTC().Add(twoDays) + c.Walk(func(items map[uint64]interface{}, key uint64) bool { + for _, rr := range items[key].([]dns.RR) { + if !rr.(*dns.RRSIG).ValidityPeriod(is75) { + delete(items, key) + } + } + return true + }) + + case <-stop: + return + } + } +} diff --git a/plugin/dnssec/cache_test.go b/plugin/dnssec/cache_test.go new file mode 100644 index 0000000..8d5ea88 --- /dev/null +++ b/plugin/dnssec/cache_test.go @@ -0,0 +1,82 @@ +package dnssec + +import ( + "testing" + "time" + + "github.com/coredns/coredns/plugin/pkg/cache" + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" +) + +func TestCacheSet(t *testing.T) { + fPriv, rmPriv, _ := test.TempFile(".", privKey) + fPub, rmPub, _ := test.TempFile(".", pubKey) + defer rmPriv() + defer rmPub() + + dnskey, err := ParseKeyFile(fPub, fPriv) + if err != nil { + t.Fatalf("Failed to parse key: %v\n", err) + } + + c := cache.New(defaultCap) + m := testMsg() + state := request.Request{Req: m, Zone: "miek.nl."} + k := hash(m.Answer) // calculate *before* we add the sig + d := New([]string{"miek.nl."}, []*DNSKEY{dnskey}, false, nil, c) + d.Sign(state, time.Now().UTC(), server) + + _, ok := d.get(k, server) + if !ok { + t.Errorf("Signature was not added to the cache") + } +} + +func TestCacheNotValidExpired(t *testing.T) { + fPriv, rmPriv, _ := test.TempFile(".", privKey) + fPub, rmPub, _ := test.TempFile(".", pubKey) + defer rmPriv() + defer rmPub() + + dnskey, err := ParseKeyFile(fPub, fPriv) + if err != nil { + t.Fatalf("Failed to parse key: %v\n", err) + } + + c := cache.New(defaultCap) + m := testMsg() + state := request.Request{Req: m, Zone: "miek.nl."} + k := hash(m.Answer) // calculate *before* we add the sig + d := New([]string{"miek.nl."}, []*DNSKEY{dnskey}, false, nil, c) + d.Sign(state, time.Now().UTC().AddDate(0, 0, -9), server) + + _, ok := d.get(k, server) + if ok { + t.Errorf("Signature was added to the cache even though not valid") + } +} + +func TestCacheNotValidYet(t *testing.T) { + fPriv, rmPriv, _ := test.TempFile(".", privKey) + fPub, rmPub, _ := test.TempFile(".", pubKey) + defer rmPriv() + defer rmPub() + + dnskey, err := ParseKeyFile(fPub, fPriv) + if err != nil { + t.Fatalf("Failed to parse key: %v\n", err) + } + + c := cache.New(defaultCap) + m := testMsg() + state := request.Request{Req: m, Zone: "miek.nl."} + k := hash(m.Answer) // calculate *before* we add the sig + d := New([]string{"miek.nl."}, []*DNSKEY{dnskey}, false, nil, c) + d.Sign(state, time.Now().UTC().AddDate(0, 0, +9), server) + + _, ok := d.get(k, server) + if ok { + t.Errorf("Signature was added to the cache even though not valid yet") + } +} diff --git a/plugin/dnssec/dnskey.go b/plugin/dnssec/dnskey.go new file mode 100644 index 0000000..161db94 --- /dev/null +++ b/plugin/dnssec/dnskey.go @@ -0,0 +1,95 @@ +package dnssec + +import ( + "crypto" + "crypto/ecdsa" + "crypto/rsa" + "errors" + "os" + "path/filepath" + "time" + + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" + "golang.org/x/crypto/ed25519" +) + +// DNSKEY holds a DNSSEC public and private key used for on-the-fly signing. +type DNSKEY struct { + K *dns.DNSKEY + D *dns.DS + s crypto.Signer + tag uint16 +} + +// ParseKeyFile read a DNSSEC keyfile as generated by dnssec-keygen or other +// utilities. It adds ".key" for the public key and ".private" for the private key. 
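+// Both paths are passed in explicitly; keyParse in setup.go derives them from a key's base name as
+// base+".key" and base+".private".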
+func ParseKeyFile(pubFile, privFile string) (*DNSKEY, error) { + f, e := os.Open(filepath.Clean(pubFile)) + if e != nil { + return nil, e + } + defer f.Close() + k, e := dns.ReadRR(f, pubFile) + if e != nil { + return nil, e + } + + f, e = os.Open(filepath.Clean(privFile)) + if e != nil { + return nil, e + } + defer f.Close() + + dk, ok := k.(*dns.DNSKEY) + if !ok { + return nil, errors.New("no public key found") + } + p, e := dk.ReadPrivateKey(f, privFile) + if e != nil { + return nil, e + } + + if s, ok := p.(*rsa.PrivateKey); ok { + return &DNSKEY{K: dk, D: dk.ToDS(dns.SHA256), s: s, tag: dk.KeyTag()}, nil + } + if s, ok := p.(*ecdsa.PrivateKey); ok { + return &DNSKEY{K: dk, D: dk.ToDS(dns.SHA256), s: s, tag: dk.KeyTag()}, nil + } + if s, ok := p.(ed25519.PrivateKey); ok { + return &DNSKEY{K: dk, D: dk.ToDS(dns.SHA256), s: s, tag: dk.KeyTag()}, nil + } + return &DNSKEY{K: dk, D: dk.ToDS(dns.SHA256), s: nil, tag: 0}, errors.New("no private key found") +} + +// getDNSKEY returns the correct DNSKEY to the client. Signatures are added when do is true. +func (d Dnssec) getDNSKEY(state request.Request, zone string, do bool, server string) *dns.Msg { + keys := make([]dns.RR, len(d.keys)) + for i, k := range d.keys { + keys[i] = dns.Copy(k.K) + keys[i].Header().Name = zone + } + m := new(dns.Msg) + m.SetReply(state.Req) + m.Answer = keys + if !do { + return m + } + + incep, expir := incepExpir(time.Now().UTC()) + if sigs, err := d.sign(keys, zone, 3600, incep, expir, server); err == nil { + m.Answer = append(m.Answer, sigs...) + } + return m +} + +// Return true if, and only if, this is a zone key with the SEP bit unset. This implies a ZSK (rfc4034 2.1.1). +func (k DNSKEY) isZSK() bool { + return k.K.Flags&(1<<8) == (1<<8) && k.K.Flags&1 == 0 +} + +// Return true if, and only if, this is a zone key with the SEP bit set. This implies a KSK (rfc4034 2.1.1). +func (k DNSKEY) isKSK() bool { + return k.K.Flags&(1<<8) == (1<<8) && k.K.Flags&1 == 1 +} diff --git a/plugin/dnssec/dnssec.go b/plugin/dnssec/dnssec.go new file mode 100644 index 0000000..edda7a8 --- /dev/null +++ b/plugin/dnssec/dnssec.go @@ -0,0 +1,179 @@ +// Package dnssec implements a plugin that signs responses on-the-fly using +// NSEC black lies. +package dnssec + +import ( + "time" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/cache" + "github.com/coredns/coredns/plugin/pkg/response" + "github.com/coredns/coredns/plugin/pkg/singleflight" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// Dnssec signs the reply on-the-fly. +type Dnssec struct { + Next plugin.Handler + + zones []string + keys []*DNSKEY + splitkeys bool + inflight *singleflight.Group + cache *cache.Cache +} + +// New returns a new Dnssec. +func New(zones []string, keys []*DNSKEY, splitkeys bool, next plugin.Handler, c *cache.Cache) Dnssec { + return Dnssec{Next: next, + zones: zones, + keys: keys, + splitkeys: splitkeys, + cache: c, + inflight: new(singleflight.Group), + } +} + +// Sign signs the message in state. it takes care of negative or nodata responses. It +// uses NSEC black lies for authenticated denial of existence. For delegations it +// will insert DS records and sign those. +// Signatures will be cached for a short while. By default we sign for 8 days, +// starting 3 hours ago. 
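+// Signatures are cached keyed on a hash of the RRset and are no longer reused once roughly 75% of
+// their 8-day validity has passed (see cache.go and get below).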
+func (d Dnssec) Sign(state request.Request, now time.Time, server string) *dns.Msg { + req := state.Req + + incep, expir := incepExpir(now) + + mt, _ := response.Typify(req, time.Now().UTC()) // TODO(miek): need opt record here? + if mt == response.Delegation { + // We either sign DS or NSEC of DS. + ttl := req.Ns[0].Header().Ttl + + ds := []dns.RR{} + for i := range req.Ns { + if req.Ns[i].Header().Rrtype == dns.TypeDS { + ds = append(ds, req.Ns[i]) + } + } + if len(ds) == 0 { + if sigs, err := d.nsec(state, mt, ttl, incep, expir, server); err == nil { + req.Ns = append(req.Ns, sigs...) + } + } else if sigs, err := d.sign(ds, state.Zone, ttl, incep, expir, server); err == nil { + req.Ns = append(req.Ns, sigs...) + } + return req + } + + if mt == response.NameError || mt == response.NoData { + if req.Ns[0].Header().Rrtype != dns.TypeSOA || len(req.Ns) > 1 { + return req + } + + ttl := req.Ns[0].Header().Ttl + + if sigs, err := d.sign(req.Ns, state.Zone, ttl, incep, expir, server); err == nil { + req.Ns = append(req.Ns, sigs...) + } + if sigs, err := d.nsec(state, mt, ttl, incep, expir, server); err == nil { + req.Ns = append(req.Ns, sigs...) + } + if len(req.Ns) > 1 { // actually added nsec and sigs, reset the rcode + req.Rcode = dns.RcodeSuccess + if state.QType() == dns.TypeNSEC { // If original query was NSEC move Ns to Answer without SOA + req.Answer = req.Ns[len(req.Ns)-2 : len(req.Ns)] + req.Ns = nil + } + } + return req + } + + for _, r := range rrSets(req.Answer) { + ttl := r[0].Header().Ttl + if sigs, err := d.sign(r, state.Zone, ttl, incep, expir, server); err == nil { + req.Answer = append(req.Answer, sigs...) + } + } + for _, r := range rrSets(req.Ns) { + ttl := r[0].Header().Ttl + if sigs, err := d.sign(r, state.Zone, ttl, incep, expir, server); err == nil { + req.Ns = append(req.Ns, sigs...) + } + } + for _, r := range rrSets(req.Extra) { + ttl := r[0].Header().Ttl + if sigs, err := d.sign(r, state.Zone, ttl, incep, expir, server); err == nil { + req.Extra = append(req.Extra, sigs...) + } + } + return req +} + +func (d Dnssec) sign(rrs []dns.RR, signerName string, ttl, incep, expir uint32, server string) ([]dns.RR, error) { + k := hash(rrs) + sgs, ok := d.get(k, server) + if ok { + return sgs, nil + } + + sigs, err := d.inflight.Do(k, func() (interface{}, error) { + var sigs []dns.RR + for _, k := range d.keys { + if d.splitkeys { + if len(rrs) > 0 && rrs[0].Header().Rrtype == dns.TypeDNSKEY { + // We are signing a DNSKEY RRSet. With split keys, we need to use a KSK here. + if !k.isKSK() { + continue + } + } else { + // For non-DNSKEY RRSets, we want to use a ZSK. 
+ if !k.isZSK() { + continue + } + } + } + sig := k.newRRSIG(signerName, ttl, incep, expir) + if e := sig.Sign(k.s, rrs); e != nil { + return sigs, e + } + sigs = append(sigs, sig) + } + d.set(k, sigs) + return sigs, nil + }) + return sigs.([]dns.RR), err +} + +func (d Dnssec) set(key uint64, sigs []dns.RR) { d.cache.Add(key, sigs) } + +func (d Dnssec) get(key uint64, server string) ([]dns.RR, bool) { + if s, ok := d.cache.Get(key); ok { + // we sign for 8 days, check if a signature in the cache reached 3/4 of that + is75 := time.Now().UTC().Add(twoDays) + for _, rr := range s.([]dns.RR) { + if !rr.(*dns.RRSIG).ValidityPeriod(is75) { + cacheMisses.WithLabelValues(server).Inc() + return nil, false + } + } + + cacheHits.WithLabelValues(server).Inc() + return s.([]dns.RR), true + } + cacheMisses.WithLabelValues(server).Inc() + return nil, false +} + +func incepExpir(now time.Time) (uint32, uint32) { + incep := uint32(now.Add(-3 * time.Hour).Unix()) // -(2+1) hours, be sure to catch daylight saving time and such + expir := uint32(now.Add(eightDays).Unix()) // sign for 8 days + return incep, expir +} + +const ( + eightDays = 8 * 24 * time.Hour + twoDays = 2 * 24 * time.Hour + defaultCap = 10000 // default capacity of the cache. +) diff --git a/plugin/dnssec/dnssec_test.go b/plugin/dnssec/dnssec_test.go new file mode 100644 index 0000000..8b55ea3 --- /dev/null +++ b/plugin/dnssec/dnssec_test.go @@ -0,0 +1,286 @@ +package dnssec + +import ( + "testing" + "time" + + "github.com/coredns/coredns/plugin/pkg/cache" + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +func TestZoneSigning(t *testing.T) { + d, rm1, rm2 := newDnssec(t, []string{"miek.nl."}) + defer rm1() + defer rm2() + + m := testMsg() + state := request.Request{Req: m, Zone: "miek.nl."} + + m = d.Sign(state, time.Now().UTC(), server) + if !section(m.Answer, 1) { + t.Errorf("Answer section should have 1 RRSIG") + } + if !section(m.Ns, 1) { + t.Errorf("Authority section should have 1 RRSIG") + } +} + +func TestZoneSigningDouble(t *testing.T) { + d, rm1, rm2 := newDnssec(t, []string{"miek.nl."}) + defer rm1() + defer rm2() + + fPriv1, rmPriv1, _ := test.TempFile(".", privKey1) + fPub1, rmPub1, _ := test.TempFile(".", pubKey1) + defer rmPriv1() + defer rmPub1() + + key1, err := ParseKeyFile(fPub1, fPriv1) + if err != nil { + t.Fatalf("Failed to parse key: %v\n", err) + } + d.keys = append(d.keys, key1) + + m := testMsg() + state := request.Request{Req: m, Zone: "miek.nl."} + m = d.Sign(state, time.Now().UTC(), server) + if !section(m.Answer, 2) { + t.Errorf("Answer section should have 1 RRSIG") + } + if !section(m.Ns, 2) { + t.Errorf("Authority section should have 1 RRSIG") + } +} + +// TestSigningDifferentZone tests if a key for miek.nl and be used for example.org. 
+func TestSigningDifferentZone(t *testing.T) { + fPriv, rmPriv, _ := test.TempFile(".", privKey) + fPub, rmPub, _ := test.TempFile(".", pubKey) + defer rmPriv() + defer rmPub() + + key, err := ParseKeyFile(fPub, fPriv) + if err != nil { + t.Fatalf("Failed to parse key: %v\n", err) + } + + m := testMsgEx() + state := request.Request{Req: m, Zone: "example.org."} + c := cache.New(defaultCap) + d := New([]string{"example.org."}, []*DNSKEY{key}, false, nil, c) + m = d.Sign(state, time.Now().UTC(), server) + if !section(m.Answer, 1) { + t.Errorf("Answer section should have 1 RRSIG") + t.Logf("%+v\n", m) + } + if !section(m.Ns, 1) { + t.Errorf("Authority section should have 1 RRSIG") + t.Logf("%+v\n", m) + } +} + +func TestSigningCname(t *testing.T) { + d, rm1, rm2 := newDnssec(t, []string{"miek.nl."}) + defer rm1() + defer rm2() + + m := testMsgCname() + state := request.Request{Req: m, Zone: "miek.nl."} + m = d.Sign(state, time.Now().UTC(), server) + if !section(m.Answer, 1) { + t.Errorf("Answer section should have 1 RRSIG") + } +} + +func TestSigningDname(t *testing.T) { + d, rm1, rm2 := newDnssec(t, []string{"miek.nl."}) + defer rm1() + defer rm2() + + m := testMsgDname() + state := request.Request{Req: m, Zone: "miek.nl."} + // We sign *everything* we see, also the synthesized CNAME. + m = d.Sign(state, time.Now().UTC(), server) + if !section(m.Answer, 3) { + t.Errorf("Answer section should have 3 RRSIGs") + } +} + +func TestSigningEmpty(t *testing.T) { + d, rm1, rm2 := newDnssec(t, []string{"miek.nl."}) + defer rm1() + defer rm2() + + m := testEmptyMsg() + m.SetQuestion("a.miek.nl.", dns.TypeA) + state := request.Request{Req: m, Zone: "miek.nl."} + m = d.Sign(state, time.Now().UTC(), server) + if !section(m.Ns, 2) { + t.Errorf("Authority section should have 2 RRSIGs") + } +} + +func TestDelegationSigned(t *testing.T) { + d, rm1, rm2 := newDnssec(t, []string{"miek.nl."}) + defer rm1() + defer rm2() + + m := testMsgDelegationSigned() + m.SetQuestion("sub.miek.nl.", dns.TypeNS) + state := request.Request{Req: m, Zone: "miek.nl."} + m = d.Sign(state, time.Now().UTC(), server) + if !section(m.Ns, 1) { + t.Errorf("Authority section should have 1 RRSIGs") + } + if !section(m.Extra, 0) { + t.Error("Extra section should not have RRSIGs") + } +} + +func TestDelegationUnSigned(t *testing.T) { + d, rm1, rm2 := newDnssec(t, []string{"miek.nl."}) + defer rm1() + defer rm2() + + m := testMsgDelegationUnSigned() + m.SetQuestion("sub.miek.nl.", dns.TypeNS) + state := request.Request{Req: m, Zone: "miek.nl."} + m = d.Sign(state, time.Now().UTC(), server) + if !section(m.Ns, 1) { + t.Errorf("Authority section should have 1 RRSIG") + } + if !section(m.Extra, 0) { + t.Error("Extra section should not have RRSIG") + } + var nsec *dns.NSEC + var rrsig *dns.RRSIG + for _, r := range m.Ns { + if r.Header().Rrtype == dns.TypeNSEC { + nsec = r.(*dns.NSEC) + } + if r.Header().Rrtype == dns.TypeRRSIG { + rrsig = r.(*dns.RRSIG) + } + } + if nsec == nil { + t.Error("Authority section should hold a NSEC record") + } + if rrsig.TypeCovered != dns.TypeNSEC { + t.Errorf("RRSIG should cover type %s, got %s", + dns.TypeToString[dns.TypeNSEC], dns.TypeToString[rrsig.TypeCovered]) + } + if !correctNsecForDS(nsec) { + t.Error("NSEC as invalid TypeBitMap for a DS") + } +} + +func section(rss []dns.RR, nrSigs int) bool { + i := 0 + for _, r := range rss { + if r.Header().Rrtype == dns.TypeRRSIG { + i++ + } + } + return nrSigs == i +} + +func testMsg() *dns.Msg { + // don't care about the message header + return &dns.Msg{ + Answer: 
[]dns.RR{test.MX("miek.nl. 1703 IN MX 1 aspmx.l.google.com.")}, + Ns: []dns.RR{test.NS("miek.nl. 1703 IN NS omval.tednet.nl.")}, + } +} +func testMsgEx() *dns.Msg { + return &dns.Msg{ + Answer: []dns.RR{test.MX("example.org. 1703 IN MX 1 aspmx.l.google.com.")}, + Ns: []dns.RR{test.NS("example.org. 1703 IN NS omval.tednet.nl.")}, + } +} + +func testMsgCname() *dns.Msg { + return &dns.Msg{ + Answer: []dns.RR{test.CNAME("www.miek.nl. 1800 IN CNAME a.miek.nl.")}, + } +} + +func testMsgDname() *dns.Msg { + return &dns.Msg{ + Answer: []dns.RR{ + test.CNAME("a.dname.miek.nl. 1800 IN CNAME a.test.miek.nl."), + test.A("a.test.miek.nl. 1800 IN A 139.162.196.78"), + test.DNAME("dname.miek.nl. 1800 IN DNAME test.miek.nl."), + }, + } +} + +func testMsgDelegationSigned() *dns.Msg { + return &dns.Msg{ + Ns: []dns.RR{ + test.NS("sub.miek.nl. 1800 IN NS ns1.sub.miek.nl."), + test.DS("sub." + dsKey), + }, + Extra: []dns.RR{ + test.A("ns1.sub.miek.nl. 1800 IN A 192.0.2.1"), + }, + } +} + +func testMsgDelegationUnSigned() *dns.Msg { + return &dns.Msg{ + Ns: []dns.RR{ + test.NS("sub.miek.nl. 1800 IN NS ns1.sub.miek.nl."), + }, + Extra: []dns.RR{ + test.A("ns1.sub.miek.nl. 1800 IN A 192.0.2.1"), + }, + } +} + +func testEmptyMsg() *dns.Msg { + // don't care about the message header + return &dns.Msg{ + Ns: []dns.RR{test.SOA("miek.nl. 1800 IN SOA ns.miek.nl. dnsmaster.miek.nl. 2017100301 200 100 604800 3600")}, + } +} + +func newDnssec(t *testing.T, zones []string) (Dnssec, func(), func()) { + k, rm1, rm2 := newKey(t) + c := cache.New(defaultCap) + d := New(zones, []*DNSKEY{k}, false, nil, c) + return d, rm1, rm2 +} + +func newKey(t *testing.T) (*DNSKEY, func(), func()) { + fPriv, rmPriv, _ := test.TempFile(".", privKey) + fPub, rmPub, _ := test.TempFile(".", pubKey) + + key, err := ParseKeyFile(fPub, fPriv) + if err != nil { + t.Fatalf("Failed to parse key: %v\n", err) + } + return key, rmPriv, rmPub +} + +const ( + pubKey = `miek.nl. IN DNSKEY 257 3 13 0J8u0XJ9GNGFEBXuAmLu04taHG4BXPP3gwhetiOUMnGA+x09nqzgF5IY OyjWB7N3rXqQbnOSILhH1hnuyh7mmA==` + privKey = `Private-key-format: v1.3 +Algorithm: 13 (ECDSAP256SHA256) +PrivateKey: /4BZk8AFvyW5hL3cOLSVxIp1RTqHSAEloWUxj86p3gs= +Created: 20160423195532 +Publish: 20160423195532 +Activate: 20160423195532 +` + dsKey = `miek.nl. IN DS 18512 13 2 D4E806322598BC97A003EF1ACDFF352EEFF7B42DBB0D41B8224714C36AEF08D9` + pubKey1 = `example.org. IN DNSKEY 257 3 13 tVRWNSGpHZbCi7Pr7OmbADVUO3MxJ0Lb8Lk3o/HBHqCxf5K/J50lFqRa 98lkdAIiFOVRy8LyMvjwmxZKwB5MNw==` + privKey1 = `Private-key-format: v1.3 +Algorithm: 13 (ECDSAP256SHA256) +PrivateKey: i8j4OfDGT8CQt24SDwLz2hg9yx4qKOEOh1LvbAuSp1c= +Created: 20160423211746 +Publish: 20160423211746 +Activate: 20160423211746 +` +) diff --git a/plugin/dnssec/handler.go b/plugin/dnssec/handler.go new file mode 100644 index 0000000..1ab70ab --- /dev/null +++ b/plugin/dnssec/handler.go @@ -0,0 +1,50 @@ +package dnssec + +import ( + "context" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/metrics" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// ServeDNS implements the plugin.Handler interface. 
+func (d Dnssec) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + state := request.Request{W: w, Req: r} + + do := state.Do() + qname := state.Name() + qtype := state.QType() + zone := plugin.Zones(d.zones).Matches(qname) + if zone == "" { + return plugin.NextOrFailure(d.Name(), d.Next, ctx, w, r) + } + + state.Zone = zone + server := metrics.WithServer(ctx) + + // Intercept queries for DNSKEY, but only if one of the zones matches the qname, otherwise we let + // the query through. + if qtype == dns.TypeDNSKEY { + for _, z := range d.zones { + if qname == z { + resp := d.getDNSKEY(state, z, do, server) + resp.Authoritative = true + w.WriteMsg(resp) + return dns.RcodeSuccess, nil + } + } + } + + if do { + drr := &ResponseWriter{w, d, server} + return plugin.NextOrFailure(d.Name(), d.Next, ctx, drr, r) + } + + return plugin.NextOrFailure(d.Name(), d.Next, ctx, w, r) +} + +// Name implements the Handler interface. +func (d Dnssec) Name() string { return "dnssec" } diff --git a/plugin/dnssec/handler_test.go b/plugin/dnssec/handler_test.go new file mode 100644 index 0000000..e82e546 --- /dev/null +++ b/plugin/dnssec/handler_test.go @@ -0,0 +1,253 @@ +package dnssec + +import ( + "context" + "strings" + "testing" + + "github.com/coredns/coredns/plugin/file" + "github.com/coredns/coredns/plugin/pkg/cache" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +var dnssecTestCases = []test.Case{ + { + Qname: "miek.nl.", Qtype: dns.TypeDNSKEY, + Answer: []dns.RR{ + test.DNSKEY("miek.nl. 3600 IN DNSKEY 257 3 13 0J8u0XJ9GNGFEBXuAmLu04taHG4"), + }, + }, + { + Qname: "miek.nl.", Qtype: dns.TypeDNSKEY, Do: true, + Answer: []dns.RR{ + test.DNSKEY("miek.nl. 3600 IN DNSKEY 257 3 13 0J8u0XJ9GNGFEBXuAmLu04taHG4"), + test.RRSIG("miek.nl. 3600 IN RRSIG DNSKEY 13 2 3600 20160503150844 20160425120844 18512 miek.nl. Iw/kNOyM"), + }, + /* Extra: []dns.RR{test.OPT(4096, true)}, this has moved to the server and can't be test here */ + }, +} + +var dnsTestCases = []test.Case{ + { + Qname: "miek.nl.", Qtype: dns.TypeDNSKEY, + Answer: []dns.RR{ + test.DNSKEY("miek.nl. 3600 IN DNSKEY 257 3 13 0J8u0XJ9GNGFEBXuAmLu04taHG4"), + }, + }, + { + Qname: "miek.nl.", Qtype: dns.TypeNS, Do: true, + Answer: []dns.RR{ + test.NS("miek.nl. 1800 IN NS linode.atoom.net."), + test.RRSIG("miek.nl. 1800 IN RRSIG NS 13 2 1800 20220101121212 20220201121212 18512 miek.nl. RandomNotChecked"), + }, + }, + { + Qname: "deleg.miek.nl.", Qtype: dns.TypeNS, Do: true, + Ns: []dns.RR{ + test.DS("deleg.miek.nl. 1800 IN DS 18512 13 2 D4E806322598BC97A003EF1ACDFF352EEFF7B42DBB0D41B8224714C36AEF08D9"), + test.NS("deleg.miek.nl. 1800 IN NS ns01.deleg.miek.nl."), + test.RRSIG("deleg.miek.nl. 1800 IN RRSIG DS 13 3 1800 20220101121212 20220201121212 18512 miek.nl. RandomNotChecked"), + }, + }, + { + Qname: "unsigned.miek.nl.", Qtype: dns.TypeNS, Do: true, + Ns: []dns.RR{ + test.NS("unsigned.miek.nl. 1800 IN NS ns01.deleg.miek.nl."), + test.NSEC("unsigned.miek.nl. 1800 IN NSEC unsigned\\000.miek.nl. NS RRSIG NSEC"), + test.RRSIG("unsigned.miek.nl. 1800 IN RRSIG NSEC 13 3 1800 20220101121212 20220201121212 18512 miek.nl. RandomNotChecked"), + }, + }, + { // DS should not come from dnssec plugin + Qname: "deleg.miek.nl.", Qtype: dns.TypeDS, + Answer: []dns.RR{ + test.DS("deleg.miek.nl. 1800 IN DS 18512 13 2 D4E806322598BC97A003EF1ACDFF352EEFF7B42DBB0D41B8224714C36AEF08D9"), + }, + Ns: []dns.RR{ + test.NS("miek.nl. 
1800 IN NS linode.atoom.net."), + }, + }, + { + Qname: "unsigned.miek.nl.", Qtype: dns.TypeDS, + Ns: []dns.RR{ + test.SOA("miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 1282630057 14400 3600 604800 14400"), + }, + }, + { + Qname: "miek.nl.", Qtype: dns.TypeDS, Do: true, + Ns: []dns.RR{ + test.NSEC("miek.nl. 1800 IN NSEC \\000.miek.nl. A HINFO NS SOA MX TXT AAAA LOC SRV CERT SSHFP RRSIG NSEC DNSKEY TLSA HIP OPENPGPKEY SPF"), + test.RRSIG("miek.nl. 1800 IN RRSIG NSEC 13 2 1800 20220101121212 20220201121212 18512 miek.nl. RandomNotChecked"), + test.RRSIG("miek.nl. 1800 IN RRSIG SOA 13 2 3600 20171220141741 20171212111741 18512 miek.nl. 8bLTReqmuQtw=="), + test.SOA("miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 1282630057 14400 3600 604800 14400"), + }, + }, + { + Qname: "deleg.miek.nl.", Qtype: dns.TypeDS, Do: true, + Answer: []dns.RR{ + test.DS("deleg.miek.nl. 1800 IN DS 18512 13 2 D4E806322598BC97A003EF1ACDFF352EEFF7B42DBB0D41B8224714C36AEF08D9"), + test.RRSIG("deleg.miek.nl. 1800 IN RRSIG DS 13 3 1800 20220101121212 20220201121212 18512 miek.nl. RandomNotChecked"), + }, + Ns: []dns.RR{ + test.NS("miek.nl. 1800 IN NS linode.atoom.net."), + test.RRSIG("miek.nl. 1800 IN RRSIG NS 13 2 3600 20161217114912 20161209084912 18512 miek.nl. ad9gA8VWgF1H8ze9/0Rk2Q=="), + }, + }, + { + Qname: "unsigned.miek.nl.", Qtype: dns.TypeDS, Do: true, + Ns: []dns.RR{ + test.RRSIG("miek.nl. 1800 IN RRSIG SOA 13 2 3600 20171220141741 20171212111741 18512 miek.nl. 8bLTReqmuQtw=="), + test.SOA("miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 1282630057 14400 3600 604800 14400"), + test.NSEC("unsigned.miek.nl. 1800 IN NSEC \\000.unsigned.miek.nl. NS RRSIG NSEC"), + test.RRSIG("unsigned.miek.nl. 1800 IN RRSIG NSEC 13 3 1800 20220101121212 20220201121212 18512 miek.nl. RandomNotChecked"), + }, + }, + { + Qname: "miek.nl.", Qtype: dns.TypeMX, + Answer: []dns.RR{ + test.MX("miek.nl. 1800 IN MX 1 aspmx.l.google.com."), + }, + Ns: []dns.RR{ + test.NS("miek.nl. 1800 IN NS linode.atoom.net."), + }, + }, + { + Qname: "miek.nl.", Qtype: dns.TypeMX, Do: true, + Answer: []dns.RR{ + test.MX("miek.nl. 1800 IN MX 1 aspmx.l.google.com."), + test.RRSIG("miek.nl. 1800 IN RRSIG MX 13 2 3600 20160503192428 20160425162428 18512 miek.nl. 4nxuGKitXjPVA9zP1JIUvA09"), + }, + Ns: []dns.RR{ + test.NS("miek.nl. 1800 IN NS linode.atoom.net."), + test.RRSIG("miek.nl. 1800 IN RRSIG NS 13 2 3600 20161217114912 20161209084912 18512 miek.nl. ad9gA8VWgF1H8ze9/0Rk2Q=="), + }, + }, + { + Qname: "www.miek.nl.", Qtype: dns.TypeAAAA, Do: true, + Answer: []dns.RR{ + test.AAAA("a.miek.nl. 1800 IN AAAA 2a01:7e00::f03c:91ff:fef1:6735"), + test.RRSIG("a.miek.nl. 1800 IN RRSIG AAAA 13 3 3600 20160503193047 20160425163047 18512 miek.nl. UAyMG+gcnoXW3"), + test.CNAME("www.miek.nl. 1800 IN CNAME a.miek.nl."), + test.RRSIG("www.miek.nl. 1800 IN RRSIG CNAME 13 3 3600 20160503193047 20160425163047 18512 miek.nl. E3qGZn"), + }, + Ns: []dns.RR{ + test.NS("miek.nl. 1800 IN NS linode.atoom.net."), + test.RRSIG("miek.nl. 1800 IN RRSIG NS 13 2 3600 20161217114912 20161209084912 18512 miek.nl. ad9gA8VWgF1H8ze9/0Rk2Q=="), + }, + }, + { + Qname: "wwwww.miek.nl.", Qtype: dns.TypeAAAA, Do: true, + Ns: []dns.RR{ + test.RRSIG("miek.nl. 1800 IN RRSIG SOA 13 2 3600 20171220135446 20171212105446 18512 miek.nl. hCRzzjYz6w=="), + test.SOA("miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 1282630057 14400 3600 604800 14400"), + test.NSEC("wwwww.miek.nl. 1800 IN NSEC \\000.wwwww.miek.nl. 
A HINFO TXT LOC SRV CERT SSHFP RRSIG NSEC TLSA HIP OPENPGPKEY SPF"), + test.RRSIG("wwwww.miek.nl. 1800 IN RRSIG NSEC 13 3 3600 20171220135446 20171212105446 18512 miek.nl. cVUQWs8xw=="), + }, + }, + { + Qname: "miek.nl.", Qtype: dns.TypeHINFO, Do: true, + Ns: []dns.RR{ + test.NSEC("miek.nl. 1800 IN NSEC \\000.miek.nl. A NS SOA MX TXT AAAA LOC SRV CERT SSHFP RRSIG NSEC DNSKEY TLSA HIP OPENPGPKEY SPF"), + test.RRSIG("miek.nl. 1800 IN RRSIG NSEC 13 2 3600 20171220141741 20171212111741 18512 miek.nl. GuXROL7Uu+UiPcg=="), + test.RRSIG("miek.nl. 1800 IN RRSIG SOA 13 2 3600 20171220141741 20171212111741 18512 miek.nl. 8bLTReqmuQtw=="), + test.SOA("miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 1282630057 14400 3600 604800 14400"), + }, + }, + { + Qname: "www.example.org.", Qtype: dns.TypeAAAA, Do: true, + Rcode: dns.RcodeServerFailure, + }, +} + +func TestLookupZone(t *testing.T) { + zone, err := file.Parse(strings.NewReader(dbMiekNL), "miek.nl.", "stdin", 0) + if err != nil { + return + } + fm := file.File{Next: test.ErrorHandler(), Zones: file.Zones{Z: map[string]*file.Zone{"miek.nl.": zone}, Names: []string{"miek.nl."}}} + dnskey, rm1, rm2 := newKey(t) + defer rm1() + defer rm2() + c := cache.New(defaultCap) + dh := New([]string{"miek.nl."}, []*DNSKEY{dnskey}, false, fm, c) + + for _, tc := range dnsTestCases { + m := tc.Msg() + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err := dh.ServeDNS(context.TODO(), rec, m) + if err != nil { + t.Errorf("Expected no error, got %v", err) + return + } + + if err := test.SortAndCheck(rec.Msg, tc); err != nil { + t.Error(err) + } + } +} + +func TestLookupDNSKEY(t *testing.T) { + dnskey, rm1, rm2 := newKey(t) + defer rm1() + defer rm2() + c := cache.New(defaultCap) + dh := New([]string{"miek.nl."}, []*DNSKEY{dnskey}, false, test.ErrorHandler(), c) + + for _, tc := range dnssecTestCases { + m := tc.Msg() + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err := dh.ServeDNS(context.TODO(), rec, m) + if err != nil { + t.Errorf("Expected no error, got %v", err) + return + } + + resp := rec.Msg + if !resp.Authoritative { + t.Errorf("Authoritative Answer should be true, got false") + } + + if err := test.SortAndCheck(resp, tc); err != nil { + t.Error(err) + } + + // If there is an NSEC present in authority section check if the bitmap does not have the qtype set. + for _, rr := range resp.Ns { + if n, ok := rr.(*dns.NSEC); ok { + for i := range n.TypeBitMap { + if n.TypeBitMap[i] == tc.Qtype { + t.Errorf("Bitmap contains qtype: %d", tc.Qtype) + } + } + } + } + } +} + +const dbMiekNL = ` +$TTL 30M +$ORIGIN miek.nl. +@ IN SOA linode.atoom.net. miek.miek.nl. ( + 1282630057 ; Serial + 4H ; Refresh + 1H ; Retry + 7D ; Expire + 4H ) ; Negative Cache TTL + IN NS linode.atoom.net. + + IN MX 1 aspmx.l.google.com. 
+ + IN A 139.162.196.78 + IN AAAA 2a01:7e00::f03c:91ff:fef1:6735 + +a IN A 139.162.196.78 + IN AAAA 2a01:7e00::f03c:91ff:fef1:6735 +www IN CNAME a +deleg IN NS ns01.deleg + IN DS 18512 13 2 D4E806322598BC97A003EF1ACDFF352EEFF7B42DBB0D41B8224714C36AEF08D9 +unsigned IN NS ns01.deleg +` diff --git a/plugin/dnssec/log_test.go b/plugin/dnssec/log_test.go new file mode 100644 index 0000000..e8f3a1d --- /dev/null +++ b/plugin/dnssec/log_test.go @@ -0,0 +1,5 @@ +package dnssec + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/dnssec/metrics.go b/plugin/dnssec/metrics.go new file mode 100644 index 0000000..e69dbf5 --- /dev/null +++ b/plugin/dnssec/metrics.go @@ -0,0 +1,32 @@ +package dnssec + +import ( + "github.com/coredns/coredns/plugin" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +var ( + // cacheSize is the number of elements in the dnssec cache. + cacheSize = promauto.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: plugin.Namespace, + Subsystem: "dnssec", + Name: "cache_entries", + Help: "The number of elements in the dnssec cache.", + }, []string{"server", "type"}) + // cacheHits is the count of cache hits. + cacheHits = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "dnssec", + Name: "cache_hits_total", + Help: "The count of cache hits.", + }, []string{"server"}) + // cacheMisses is the count of cache misses. + cacheMisses = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "dnssec", + Name: "cache_misses_total", + Help: "The count of cache misses.", + }, []string{"server"}) +) diff --git a/plugin/dnssec/responsewriter.go b/plugin/dnssec/responsewriter.go new file mode 100644 index 0000000..355b317 --- /dev/null +++ b/plugin/dnssec/responsewriter.go @@ -0,0 +1,43 @@ +package dnssec + +import ( + "time" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// ResponseWriter signs the response on the fly. +type ResponseWriter struct { + dns.ResponseWriter + d Dnssec + server string // server label for metrics. +} + +// WriteMsg implements the dns.ResponseWriter interface. +func (d *ResponseWriter) WriteMsg(res *dns.Msg) error { + // By definition we should sign anything that comes back, we should still figure out for + // which zone it should be. + state := request.Request{W: d.ResponseWriter, Req: res} + + zone := plugin.Zones(d.d.zones).Matches(state.Name()) + if zone == "" { + return d.ResponseWriter.WriteMsg(res) + } + state.Zone = zone + + res = d.d.Sign(state, time.Now().UTC(), d.server) + cacheSize.WithLabelValues(d.server, "signature").Set(float64(d.d.cache.Len())) + // No need for EDNS0 trickery, as that is handled by the server. + + return d.ResponseWriter.WriteMsg(res) +} + +// Write implements the dns.ResponseWriter interface. +func (d *ResponseWriter) Write(buf []byte) (int, error) { + log.Warning("Dnssec called with Write: not signing reply") + n, err := d.ResponseWriter.Write(buf) + return n, err +} diff --git a/plugin/dnssec/rrsig.go b/plugin/dnssec/rrsig.go new file mode 100644 index 0000000..250a603 --- /dev/null +++ b/plugin/dnssec/rrsig.go @@ -0,0 +1,53 @@ +package dnssec + +import "github.com/miekg/dns" + +// newRRSIG returns a new RRSIG, with all fields filled out, except the signed data. 
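+// The RRSIG's OrigTtl is fixed at origTTL (3600s), as the plugin cannot see the zone's original TTLs;
+// the header TTL follows the RRset being signed.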
+func (k *DNSKEY) newRRSIG(signerName string, ttl, incep, expir uint32) *dns.RRSIG { + sig := new(dns.RRSIG) + + sig.Hdr.Rrtype = dns.TypeRRSIG + sig.Algorithm = k.K.Algorithm + sig.KeyTag = k.tag + sig.SignerName = signerName + sig.Hdr.Ttl = ttl + sig.OrigTtl = origTTL + + sig.Inception = incep + sig.Expiration = expir + + return sig +} + +type rrset struct { + qname string + qtype uint16 +} + +// rrSets returns rrs as a map of RRsets. It skips RRSIG and OPT records as those don't need to be signed. +func rrSets(rrs []dns.RR) map[rrset][]dns.RR { + m := make(map[rrset][]dns.RR) + + for _, r := range rrs { + if r.Header().Rrtype == dns.TypeRRSIG || r.Header().Rrtype == dns.TypeOPT { + continue + } + + if s, ok := m[rrset{r.Header().Name, r.Header().Rrtype}]; ok { + s = append(s, r) + m[rrset{r.Header().Name, r.Header().Rrtype}] = s + continue + } + + s := make([]dns.RR, 1, 3) + s[0] = r + m[rrset{r.Header().Name, r.Header().Rrtype}] = s + } + + if len(m) > 0 { + return m + } + return nil +} + +const origTTL = 3600 diff --git a/plugin/dnssec/setup.go b/plugin/dnssec/setup.go new file mode 100644 index 0000000..7820e93 --- /dev/null +++ b/plugin/dnssec/setup.go @@ -0,0 +1,146 @@ +package dnssec + +import ( + "fmt" + "path/filepath" + "strconv" + "strings" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/cache" + clog "github.com/coredns/coredns/plugin/pkg/log" +) + +var log = clog.NewWithPlugin("dnssec") + +func init() { plugin.Register("dnssec", setup) } + +func setup(c *caddy.Controller) error { + zones, keys, capacity, splitkeys, err := dnssecParse(c) + if err != nil { + return plugin.Error("dnssec", err) + } + + ca := cache.New(capacity) + stop := make(chan struct{}) + + c.OnShutdown(func() error { + close(stop) + return nil + }) + c.OnStartup(func() error { + go periodicClean(ca, stop) + return nil + }) + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + return New(zones, keys, splitkeys, next, ca) + }) + + return nil +} + +func dnssecParse(c *caddy.Controller) ([]string, []*DNSKEY, int, bool, error) { + zones := []string{} + keys := []*DNSKEY{} + capacity := defaultCap + + i := 0 + for c.Next() { + if i > 0 { + return nil, nil, 0, false, plugin.ErrOnce + } + i++ + + // dnssec [zones...] + zones = plugin.OriginsFromArgsOrServerBlock(c.RemainingArgs(), c.ServerBlockKeys) + + for c.NextBlock() { + switch x := c.Val(); x { + case "key": + k, e := keyParse(c) + if e != nil { + return nil, nil, 0, false, e + } + keys = append(keys, k...) + case "cache_capacity": + if !c.NextArg() { + return nil, nil, 0, false, c.ArgErr() + } + value := c.Val() + cacheCap, err := strconv.Atoi(value) + if err != nil { + return nil, nil, 0, false, err + } + capacity = cacheCap + default: + return nil, nil, 0, false, c.Errf("unknown property '%s'", x) + } + } + } + // Check if we have both KSKs and ZSKs. + zsk, ksk := 0, 0 + for _, k := range keys { + if k.isKSK() { + ksk++ + } else if k.isZSK() { + zsk++ + } + } + splitkeys := zsk > 0 && ksk > 0 + + // Check if each keys owner name can actually sign the zones we want them to sign. 
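+	// A key is accepted only if at least one configured zone falls under the key's owner name;
+	// otherwise setup fails with an error.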
+ for _, k := range keys { + kname := plugin.Name(k.K.Header().Name) + ok := false + for i := range zones { + if kname.Matches(zones[i]) { + ok = true + break + } + } + if !ok { + return zones, keys, capacity, splitkeys, fmt.Errorf("key %s (keyid: %d) can not sign any of the zones", string(kname), k.tag) + } + } + + return zones, keys, capacity, splitkeys, nil +} + +func keyParse(c *caddy.Controller) ([]*DNSKEY, error) { + keys := []*DNSKEY{} + config := dnsserver.GetConfig(c) + + if !c.NextArg() { + return nil, c.ArgErr() + } + value := c.Val() + if value == "file" { + ks := c.RemainingArgs() + if len(ks) == 0 { + return nil, c.ArgErr() + } + + for _, k := range ks { + base := k + // Kmiek.nl.+013+26205.key, handle .private or without extension: Kmiek.nl.+013+26205 + if strings.HasSuffix(k, ".key") { + base = k[:len(k)-4] + } + if strings.HasSuffix(k, ".private") { + base = k[:len(k)-8] + } + if !filepath.IsAbs(base) && config.Root != "" { + base = filepath.Join(config.Root, base) + } + k, err := ParseKeyFile(base+".key", base+".private") + if err != nil { + return nil, err + } + keys = append(keys, k) + } + } + return keys, nil +} diff --git a/plugin/dnssec/setup_test.go b/plugin/dnssec/setup_test.go new file mode 100644 index 0000000..66ff45f --- /dev/null +++ b/plugin/dnssec/setup_test.go @@ -0,0 +1,160 @@ +package dnssec + +import ( + "os" + "strings" + "testing" + + "github.com/coredns/caddy" +) + +func TestSetupDnssec(t *testing.T) { + if err := os.WriteFile("Kcluster.local.key", []byte(keypub), 0644); err != nil { + t.Fatalf("Failed to write pub key file: %s", err) + } + defer func() { os.Remove("Kcluster.local.key") }() + if err := os.WriteFile("Kcluster.local.private", []byte(keypriv), 0644); err != nil { + t.Fatalf("Failed to write private key file: %s", err) + } + defer func() { os.Remove("Kcluster.local.private") }() + if err := os.WriteFile("ksk_Kcluster.local.key", []byte(kskpub), 0644); err != nil { + t.Fatalf("Failed to write pub key file: %s", err) + } + defer func() { os.Remove("ksk_Kcluster.local.key") }() + if err := os.WriteFile("ksk_Kcluster.local.private", []byte(kskpriv), 0644); err != nil { + t.Fatalf("Failed to write private key file: %s", err) + } + defer func() { os.Remove("ksk_Kcluster.local.private") }() + + tests := []struct { + input string + shouldErr bool + expectedZones []string + expectedKeys []string + expectedSplitkeys bool + expectedCapacity int + expectedErrContent string + }{ + {`dnssec`, false, nil, nil, false, defaultCap, ""}, + {`dnssec example.org`, false, []string{"example.org."}, nil, false, defaultCap, ""}, + {`dnssec 10.0.0.0/8`, false, []string{"10.in-addr.arpa."}, nil, false, defaultCap, ""}, + { + `dnssec example.org { + cache_capacity 100 + }`, false, []string{"example.org."}, nil, false, 100, "", + }, + { + `dnssec cluster.local { + key file Kcluster.local + }`, false, []string{"cluster.local."}, nil, false, defaultCap, "", + }, + { + `dnssec example.org cluster.local { + key file Kcluster.local + }`, false, []string{"example.org.", "cluster.local."}, nil, false, defaultCap, "", + }, + // fails + { + `dnssec example.org { + key file Kcluster.local + }`, true, []string{"example.org."}, nil, false, defaultCap, "can not sign any", + }, + { + `dnssec example.org { + key + }`, true, []string{"example.org."}, nil, false, defaultCap, "argument count", + }, + { + `dnssec example.org { + key file + }`, true, []string{"example.org."}, nil, false, defaultCap, "argument count", + }, + {`dnssec + dnssec`, true, nil, nil, false, defaultCap, ""}, + { 
+ `dnssec cluster.local { + key file Kcluster.local + key file ksk_Kcluster.local + }`, false, []string{"cluster.local."}, nil, true, defaultCap, "", + }, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + zones, keys, capacity, splitkeys, err := dnssecParse(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected error but found %s for input %s", i, err, test.input) + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: Expected no error but found one for input %s. Error was: %v", i, test.input, err) + } + + if !strings.Contains(err.Error(), test.expectedErrContent) { + t.Errorf("Test %d: Expected error to contain: %v, found error: %v, input: %s", i, test.expectedErrContent, err, test.input) + } + } + if !test.shouldErr { + for i, z := range test.expectedZones { + if zones[i] != z { + t.Errorf("Dnssec not correctly set for input %s. Expected: %s, actual: %s", test.input, z, zones[i]) + } + } + for i, k := range test.expectedKeys { + if k != keys[i].K.Header().Name { + t.Errorf("Dnssec not correctly set for input %s. Expected: '%s', actual: '%s'", test.input, k, keys[i].K.Header().Name) + } + } + if splitkeys != test.expectedSplitkeys { + t.Errorf("Detected split keys does not match. Expected: %t, actual %t", test.expectedSplitkeys, splitkeys) + } + if capacity != test.expectedCapacity { + t.Errorf("Dnssec not correctly set capacity for input '%s' Expected: '%d', actual: '%d'", test.input, capacity, test.expectedCapacity) + } + } + } +} + +const keypub = `; This is a zone-signing key, keyid 45330, for cluster.local. +; Created: 20170901060531 (Fri Sep 1 08:05:31 2017) +; Publish: 20170901060531 (Fri Sep 1 08:05:31 2017) +; Activate: 20170901060531 (Fri Sep 1 08:05:31 2017) +cluster.local. IN DNSKEY 256 3 5 AwEAAcFpDv+Cb23kFJowu+VU++b2N1uEHi6Ll9H0BzLasFOdJjEEclCO q/KlD4682vOMXxJNN8ZwOyiCa7Y0TEYqSwWvhHyn3bHCwuy4I6fss4Wd 7Y9dU+6QTgJ8LimGG40Iizjc9zqoU8Q+q81vIukpYWOHioHoY7hsWBvS RSlzDJk3` + +const keypriv = `Private-key-format: v1.3 +Algorithm: 5 (RSASHA1) +Modulus: wWkO/4JvbeQUmjC75VT75vY3W4QeLouX0fQHMtqwU50mMQRyUI6r8qUPjrza84xfEk03xnA7KIJrtjRMRipLBa+EfKfdscLC7Lgjp+yzhZ3tj11T7pBOAnwuKYYbjQiLONz3OqhTxD6rzW8i6SlhY4eKgehjuGxYG9JFKXMMmTc= +PublicExponent: AQAB +PrivateExponent: K5XyZFBPrjMVFX5gCZlyPyVDamNGrfSVXSIiMSqpS96BSdCXtmHAjCj4bZFPwkzi6+vs4tJN8p4ZifEVM0a6qwPZyENBrc2qbsweOXE6l8BaPVWFX30xvVRzGXuNtXxlBXE17zoHty5r5mRyRou1bc2HUS5otdkEjE30RiocQVk= +Prime1: 7RRFUxaZkVNVH1DaT/SV5Sb8kABB389qLwU++argeDCVf+Wm9BBlTrsz2U6bKlfpaUmYZKtCCd+CVxqzMyuu0w== +Prime2: 0NiY3d7Fa08IGY9L4TaFc02A721YcDNBBf95BP31qGvwnYsLFM/1xZwaEsIjohg8g+m/GpyIlvNMbK6pywIVjQ== +Exponent1: XjXO8pype9mMmvwrNNix9DTQ6nxfsQugW30PMHGZ78kGr6NX++bEC0xS50jYWjRDGcbYGzD+9iNujSScD3qNZw== +Exponent2: wkoOhLIfhUIj7etikyUup2Ld5WAbW15DSrotstg0NrgcQ+Q7reP96BXeJ79WeREFE09cyvv/EjdLzPv81/CbbQ== +Coefficient: ah4LL0KLTO8kSKHK+X9Ud8grYi94QSNdbX11ge/eFcS/41QhDuZRTAFv4y0+IG+VWd+XzojLsQs+jzLe5GzINg== +Created: 20170901060531 +Publish: 20170901060531 +Activate: 20170901060531 +` + +const kskpub = `; This is a zone-signing key, keyid 45330, for cluster.local. +; Created: 20170901060531 (Fri Sep 1 08:05:31 2017) +; Publish: 20170901060531 (Fri Sep 1 08:05:31 2017) +; Activate: 20170901060531 (Fri Sep 1 08:05:31 2017) +cluster.local. 
IN DNSKEY 257 3 5 AwEAAcFpDv+Cb23kFJowu+VU++b2N1uEHi6Ll9H0BzLasFOdJjEEclCO q/KlD4682vOMXxJNN8ZwOyiCa7Y0TEYqSwWvhHyn3bHCwuy4I6fss4Wd 7Y9dU+6QTgJ8LimGG40Iizjc9zqoU8Q+q81vIukpYWOHioHoY7hsWBvS RSlzDJk3` + +const kskpriv = `Private-key-format: v1.3 +Algorithm: 5 (RSASHA1) +Modulus: wWkO/4JvbeQUmjC75VT75vY3W4QeLouX0fQHMtqwU50mMQRyUI6r8qUPjrza84xfEk03xnA7KIJrtjRMRipLBa+EfKfdscLC7Lgjp+yzhZ3tj11T7pBOAnwuKYYbjQiLONz3OqhTxD6rzW8i6SlhY4eKgehjuGxYG9JFKXMMmTc= +PublicExponent: AQAB +PrivateExponent: K5XyZFBPrjMVFX5gCZlyPyVDamNGrfSVXSIiMSqpS96BSdCXtmHAjCj4bZFPwkzi6+vs4tJN8p4ZifEVM0a6qwPZyENBrc2qbsweOXE6l8BaPVWFX30xvVRzGXuNtXxlBXE17zoHty5r5mRyRou1bc2HUS5otdkEjE30RiocQVk= +Prime1: 7RRFUxaZkVNVH1DaT/SV5Sb8kABB389qLwU++argeDCVf+Wm9BBlTrsz2U6bKlfpaUmYZKtCCd+CVxqzMyuu0w== +Prime2: 0NiY3d7Fa08IGY9L4TaFc02A721YcDNBBf95BP31qGvwnYsLFM/1xZwaEsIjohg8g+m/GpyIlvNMbK6pywIVjQ== +Exponent1: XjXO8pype9mMmvwrNNix9DTQ6nxfsQugW30PMHGZ78kGr6NX++bEC0xS50jYWjRDGcbYGzD+9iNujSScD3qNZw== +Exponent2: wkoOhLIfhUIj7etikyUup2Ld5WAbW15DSrotstg0NrgcQ+Q7reP96BXeJ79WeREFE09cyvv/EjdLzPv81/CbbQ== +Coefficient: ah4LL0KLTO8kSKHK+X9Ud8grYi94QSNdbX11ge/eFcS/41QhDuZRTAFv4y0+IG+VWd+XzojLsQs+jzLe5GzINg== +Created: 20170901060531 +Publish: 20170901060531 +Activate: 20170901060531 +` diff --git a/plugin/dnstap/README.md b/plugin/dnstap/README.md new file mode 100644 index 0000000..b90c45f --- /dev/null +++ b/plugin/dnstap/README.md @@ -0,0 +1,169 @@ +# dnstap + +## Name + +*dnstap* - enables logging to dnstap. + +## Description + +dnstap is a flexible, structured binary log format for DNS software; see https://dnstap.info. With this +plugin you make CoreDNS output dnstap logging. + +Every message is sent to the socket as soon as it comes in, the *dnstap* plugin has a buffer of +10000 messages, above that number dnstap messages will be dropped (this is logged). + +## Syntax + +~~~ txt +dnstap SOCKET [full] { + [identity IDENTITY] + [version VERSION] + [extra EXTRA] + [skipverify] +} +~~~ + +* **SOCKET** is the socket (path) supplied to the dnstap command line tool. +* `full` to include the wire-format DNS message. +* **IDENTITY** to override the identity of the server. Defaults to the hostname. +* **VERSION** to override the version field. Defaults to the CoreDNS version. +* **EXTRA** to define "extra" field in dnstap payload, [metadata](../metadata/) replacement available here. +* `skipverify` to skip tls verification during connection. Default to be secure + +## Examples + +Log information about client requests and responses to */tmp/dnstap.sock*. + +~~~ txt +dnstap /tmp/dnstap.sock +~~~ + +Log information including the wire-format DNS message about client requests and responses to */tmp/dnstap.sock*. + +~~~ txt +dnstap unix:///tmp/dnstap.sock full +~~~ + +Log to a remote endpoint. + +~~~ txt +dnstap tcp://127.0.0.1:6000 full +~~~ + +Log to a remote endpoint by FQDN. + +~~~ txt +dnstap tcp://example.com:6000 full +~~~ + +Log to a socket, overriding the default identity and version. + +~~~ txt +dnstap /tmp/dnstap.sock { + identity my-dns-server1 + version MyDNSServer-1.2.3 +} +~~~ + +Log to a socket, customize the "extra" field in dnstap payload. You may use metadata provided by other plugins in the extra field. + +~~~ txt +forward . 8.8.8.8 +metadata +dnstap /tmp/dnstap.sock { + extra "upstream: {/forward/upstream}" +} +~~~ + +Log to a remote TLS endpoint. + +~~~ txt +dnstap tls://127.0.0.1:6000 full { + skipverify +} +~~~ + +You can use _dnstap_ more than once to define multiple taps. 
The following logs information including the +wire-format DNS message about client requests and responses to */tmp/dnstap.sock*, +and also sends client requests and responses without wire-format DNS messages to a remote FQDN. + +~~~ txt +dnstap /tmp/dnstap.sock full +dnstap tcp://example.com:6000 +~~~ + +## Command Line Tool + +Dnstap has a command line tool that can be used to inspect the logging. The tool can be found +at Github: <https://github.com/dnstap/golang-dnstap>. It's written in Go. + +The following command listens on the given socket and decodes messages to stdout. + +~~~ sh +$ dnstap -u /tmp/dnstap.sock +~~~ + +The following command listens on the given socket and saves message payloads to a binary dnstap-format log file. + +~~~ sh +$ dnstap -u /tmp/dnstap.sock -w /tmp/test.dnstap +~~~ + +Listen for dnstap messages on port 6000. + +~~~ sh +$ dnstap -l 127.0.0.1:6000 +~~~ + +## Using Dnstap in your plugin + +In your setup function, collect and store a list of all *dnstap* plugins loaded in the config: + +~~~ go +x := &ExamplePlugin{} + +c.OnStartup(func() error { + if taph := dnsserver.GetConfig(c).Handler("dnstap"); taph != nil { + for tapPlugin, ok := taph.(*dnstap.Dnstap); ok; tapPlugin, ok = tapPlugin.Next.(*dnstap.Dnstap) { + x.tapPlugins = append(x.tapPlugins, tapPlugin) + } + } + return nil +}) +~~~ + +And then in your plugin: + +~~~ go +import ( + "github.com/coredns/coredns/plugin/dnstap/msg" + "github.com/coredns/coredns/request" + + tap "github.com/dnstap/golang-dnstap" +) + +func (x ExamplePlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + for _, tapPlugin := range x.tapPlugins { + q := new(msg.Msg) + msg.SetQueryTime(q, time.Now()) + msg.SetQueryAddress(q, w.RemoteAddr()) + if tapPlugin.IncludeRawMessage { + buf, _ := r.Pack() // r has been seen packed/unpacked before, this should not fail + q.QueryMessage = buf + } + msg.SetType(q, tap.Message_CLIENT_QUERY) + + // if no metadata interpretation is needed, just send the message + tapPlugin.TapMessage(q) + + // OR: to interpret the metadata in "extra" field, give more context info + tapPlugin.TapMessageWithMetadata(ctx, q, request.Request{W: w, Req: query}) + } + // ... +} +~~~ + +## See Also + +The website [dnstap.info](https://dnstap.info) has info on the dnstap protocol. The *forward* +plugin's `dnstap.go` uses dnstap to tap messages sent to an upstream. diff --git a/plugin/dnstap/encoder.go b/plugin/dnstap/encoder.go new file mode 100644 index 0000000..93d3e73 --- /dev/null +++ b/plugin/dnstap/encoder.go @@ -0,0 +1,40 @@ +package dnstap + +import ( + "io" + "time" + + tap "github.com/dnstap/golang-dnstap" + fs "github.com/farsightsec/golang-framestream" + "google.golang.org/protobuf/proto" +) + +// encoder wraps a golang-framestream.Encoder. +type encoder struct { + fs *fs.Encoder +} + +func newEncoder(w io.Writer, timeout time.Duration) (*encoder, error) { + fs, err := fs.NewEncoder(w, &fs.EncoderOptions{ + ContentType: []byte("protobuf:dnstap.Dnstap"), + Bidirectional: true, + Timeout: timeout, + }) + if err != nil { + return nil, err + } + return &encoder{fs}, nil +} + +func (e *encoder) writeMsg(msg *tap.Dnstap) error { + buf, err := proto.Marshal(msg) + if err != nil { + return err + } + + _, err = e.fs.Write(buf) // n < len(buf) should return an error? 
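+	// The byte count returned by Write is discarded; only the framestream encoder's error is propagated to the caller.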
+ return err +} + +func (e *encoder) flush() error { return e.fs.Flush() } +func (e *encoder) close() error { return e.fs.Close() } diff --git a/plugin/dnstap/handler.go b/plugin/dnstap/handler.go new file mode 100644 index 0000000..59dbaba --- /dev/null +++ b/plugin/dnstap/handler.go @@ -0,0 +1,85 @@ +package dnstap + +import ( + "context" + "time" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/dnstap/msg" + "github.com/coredns/coredns/plugin/pkg/replacer" + "github.com/coredns/coredns/request" + + tap "github.com/dnstap/golang-dnstap" + "github.com/miekg/dns" +) + +// Dnstap is the dnstap handler. +type Dnstap struct { + Next plugin.Handler + io tapper + repl replacer.Replacer + + // IncludeRawMessage will include the raw DNS message into the dnstap messages if true. + IncludeRawMessage bool + Identity []byte + Version []byte + ExtraFormat string +} + +// TapMessage sends the message m to the dnstap interface, without populating "Extra" field. +func (h Dnstap) TapMessage(m *tap.Message) { + if h.ExtraFormat == "" { + h.tapWithExtra(m, nil) + } else { + h.tapWithExtra(m, []byte(h.ExtraFormat)) + } +} + +// TapMessageWithMetadata sends the message m to the dnstap interface, with "Extra" field being populated. +func (h Dnstap) TapMessageWithMetadata(ctx context.Context, m *tap.Message, state request.Request) { + if h.ExtraFormat == "" { + h.tapWithExtra(m, nil) + return + } + extraStr := h.repl.Replace(ctx, state, nil, h.ExtraFormat) + h.tapWithExtra(m, []byte(extraStr)) +} + +func (h Dnstap) tapWithExtra(m *tap.Message, extra []byte) { + t := tap.Dnstap_MESSAGE + h.io.Dnstap(&tap.Dnstap{Type: &t, Message: m, Identity: h.Identity, Version: h.Version, Extra: extra}) +} + +func (h Dnstap) tapQuery(ctx context.Context, w dns.ResponseWriter, query *dns.Msg, queryTime time.Time) { + q := new(tap.Message) + msg.SetQueryTime(q, queryTime) + msg.SetQueryAddress(q, w.RemoteAddr()) + + if h.IncludeRawMessage { + buf, _ := query.Pack() + q.QueryMessage = buf + } + msg.SetType(q, tap.Message_CLIENT_QUERY) + state := request.Request{W: w, Req: query} + h.TapMessageWithMetadata(ctx, q, state) +} + +// ServeDNS logs the client query and response to dnstap and passes the dnstap Context. +func (h Dnstap) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + rw := &ResponseWriter{ + ResponseWriter: w, + Dnstap: h, + query: r, + ctx: ctx, + queryTime: time.Now(), + } + + // The query tap message should be sent before sending the query to the + // forwarder. Otherwise, the tap messages will come out out of order. + h.tapQuery(ctx, w, r, rw.queryTime) + + return plugin.NextOrFailure(h.Name(), h.Next, ctx, rw, r) +} + +// Name implements the plugin.Plugin interface. 
+func (h Dnstap) Name() string { return "dnstap" } diff --git a/plugin/dnstap/handler_test.go b/plugin/dnstap/handler_test.go new file mode 100644 index 0000000..cb492ac --- /dev/null +++ b/plugin/dnstap/handler_test.go @@ -0,0 +1,136 @@ +package dnstap + +import ( + "context" + "net" + "testing" + + "github.com/coredns/coredns/plugin/dnstap/msg" + "github.com/coredns/coredns/plugin/metadata" + test "github.com/coredns/coredns/plugin/test" + + tap "github.com/dnstap/golang-dnstap" + "github.com/miekg/dns" +) + +func testCase(t *testing.T, tapq, tapr *tap.Dnstap, q, r *dns.Msg, extraFormat string) { + w := writer{t: t} + w.queue = append(w.queue, tapq, tapr) + h := Dnstap{ + Next: test.HandlerFunc(func(_ context.Context, + w dns.ResponseWriter, _ *dns.Msg) (int, error) { + return 0, w.WriteMsg(r) + }), + io: &w, + ExtraFormat: extraFormat, + } + ctx := metadata.ContextWithMetadata(context.TODO()) + ok := metadata.SetValueFunc(ctx, "metadata/test", func() string { + return "MetadataValue" + }) + if !ok { + t.Fatal("Failed to set metadata") + } + _, err := h.ServeDNS(ctx, &test.ResponseWriter{}, q) + if err != nil { + t.Fatal(err) + } +} + +type writer struct { + t *testing.T + queue []*tap.Dnstap +} + +func (w *writer) Dnstap(e *tap.Dnstap) { + if len(w.queue) == 0 { + w.t.Error("Message not expected") + } + + ex := w.queue[0].Message + got := e.Message + + if string(ex.QueryAddress) != string(got.QueryAddress) { + w.t.Errorf("Expected source address %s, got %s", ex.QueryAddress, got.QueryAddress) + } + if string(ex.ResponseAddress) != string(got.ResponseAddress) { + w.t.Errorf("Expected response address %s, got %s", ex.ResponseAddress, got.ResponseAddress) + } + if *ex.QueryPort != *got.QueryPort { + w.t.Errorf("Expected port %d, got %d", *ex.QueryPort, *got.QueryPort) + } + if *ex.SocketFamily != *got.SocketFamily { + w.t.Errorf("Expected socket family %d, got %d", *ex.SocketFamily, *got.SocketFamily) + } + if string(w.queue[0].Extra) != string(e.Extra) { + w.t.Errorf("Expected extra %s, got %s", w.queue[0].Extra, e.Extra) + } + w.queue = w.queue[1:] +} + +func TestDnstap(t *testing.T) { + q := test.Case{Qname: "example.org", Qtype: dns.TypeA}.Msg() + r := test.Case{ + Qname: "example.org.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.A("example.org. 
3600 IN A 10.0.0.1"), + }, + }.Msg() + + tapq := &tap.Dnstap{ + Message: testMessage(), + } + msg.SetType(tapq.Message, tap.Message_CLIENT_QUERY) + tapr := &tap.Dnstap{ + Message: testMessage(), + } + msg.SetType(tapr.Message, tap.Message_CLIENT_RESPONSE) + testCase(t, tapq, tapr, q, r, "") + + tapq_with_extra := &tap.Dnstap{ + Message: testMessage(), // leave type unset for deepEqual + Extra: []byte("extra_field_MetadataValue_A_example.org._IN_udp_29_10.240.0.1_40212_127.0.0.1"), + } + msg.SetType(tapq_with_extra.Message, tap.Message_CLIENT_QUERY) + tapr_with_extra := &tap.Dnstap{ + Message: testMessage(), + Extra: []byte("extra_field_MetadataValue_A_example.org._IN_udp_29_10.240.0.1_40212_127.0.0.1"), + } + msg.SetType(tapr_with_extra.Message, tap.Message_CLIENT_RESPONSE) + extraFormat := "extra_field_{/metadata/test}_{type}_{name}_{class}_{proto}_{size}_{remote}_{port}_{local}" + testCase(t, tapq_with_extra, tapr_with_extra, q, r, extraFormat) +} + +func testMessage() *tap.Message { + inet := tap.SocketFamily_INET + udp := tap.SocketProtocol_UDP + port := uint32(40212) + return &tap.Message{ + SocketFamily: &inet, + SocketProtocol: &udp, + QueryAddress: net.ParseIP("10.240.0.1"), + QueryPort: &port, + } +} + +func TestTapMessage(t *testing.T) { + extraFormat := "extra_field_no_replacement_{/metadata/test}_{type}_{name}_{class}_{proto}_{size}_{remote}_{port}_{local}" + tapq := &tap.Dnstap{ + Message: testMessage(), + // extra field would not be replaced, since TapMessage won't pass context + Extra: []byte(extraFormat), + } + msg.SetType(tapq.Message, tap.Message_CLIENT_QUERY) + + w := writer{t: t} + w.queue = append(w.queue, tapq) + h := Dnstap{ + Next: test.HandlerFunc(func(_ context.Context, + w dns.ResponseWriter, r *dns.Msg) (int, error) { + return 0, w.WriteMsg(r) + }), + io: &w, + ExtraFormat: extraFormat, + } + h.TapMessage(tapq.Message) +} diff --git a/plugin/dnstap/io.go b/plugin/dnstap/io.go new file mode 100644 index 0000000..f95e4b5 --- /dev/null +++ b/plugin/dnstap/io.go @@ -0,0 +1,143 @@ +package dnstap + +import ( + "crypto/tls" + "net" + "sync/atomic" + "time" + + tap "github.com/dnstap/golang-dnstap" +) + +const ( + tcpWriteBufSize = 1024 * 1024 // there is no good explanation for why this number has this value. + queueSize = 10000 // idem. + + tcpTimeout = 4 * time.Second + flushTimeout = 1 * time.Second + + skipVerify = false // by default, every tls connection is verified to be secure +) + +// tapper interface is used in testing to mock the Dnstap method. +type tapper interface { + Dnstap(*tap.Dnstap) +} + +// dio implements the Tapper interface. +type dio struct { + endpoint string + proto string + enc *encoder + queue chan *tap.Dnstap + dropped uint32 + quit chan struct{} + flushTimeout time.Duration + tcpTimeout time.Duration + skipVerify bool +} + +// newIO returns a new and initialized pointer to a dio. 
+func newIO(proto, endpoint string) *dio { + return &dio{ + endpoint: endpoint, + proto: proto, + queue: make(chan *tap.Dnstap, queueSize), + quit: make(chan struct{}), + flushTimeout: flushTimeout, + tcpTimeout: tcpTimeout, + skipVerify: skipVerify, + } +} + +func (d *dio) dial() error { + var conn net.Conn + var err error + + if d.proto == "tls" { + config := &tls.Config{ + InsecureSkipVerify: d.skipVerify, + } + dialer := &net.Dialer{ + Timeout: d.tcpTimeout, + } + conn, err = tls.DialWithDialer(dialer, "tcp", d.endpoint, config) + if err != nil { + return err + } + } else { + conn, err = net.DialTimeout(d.proto, d.endpoint, d.tcpTimeout) + if err != nil { + return err + } + } + + if tcpConn, ok := conn.(*net.TCPConn); ok { + tcpConn.SetWriteBuffer(tcpWriteBufSize) + tcpConn.SetNoDelay(false) + } + + d.enc, err = newEncoder(conn, d.tcpTimeout) + return err +} + +// Connect connects to the dnstap endpoint. +func (d *dio) connect() error { + err := d.dial() + go d.serve() + return err +} + +// Dnstap enqueues the payload for log. +func (d *dio) Dnstap(payload *tap.Dnstap) { + select { + case d.queue <- payload: + default: + atomic.AddUint32(&d.dropped, 1) + } +} + +// close waits until the I/O routine is finished to return. +func (d *dio) close() { close(d.quit) } + +func (d *dio) write(payload *tap.Dnstap) error { + if d.enc == nil { + atomic.AddUint32(&d.dropped, 1) + return nil + } + if err := d.enc.writeMsg(payload); err != nil { + atomic.AddUint32(&d.dropped, 1) + return err + } + return nil +} + +func (d *dio) serve() { + timeout := time.NewTimer(d.flushTimeout) + defer timeout.Stop() + for { + timeout.Reset(d.flushTimeout) + select { + case <-d.quit: + if d.enc == nil { + return + } + d.enc.flush() + d.enc.close() + return + case payload := <-d.queue: + if err := d.write(payload); err != nil { + d.dial() + } + case <-timeout.C: + if dropped := atomic.SwapUint32(&d.dropped, 0); dropped > 0 { + log.Warningf("Dropped dnstap messages: %d", dropped) + } + if d.enc == nil { + d.dial() + } else { + d.enc.flush() + } + } + } +} diff --git a/plugin/dnstap/io_test.go b/plugin/dnstap/io_test.go new file mode 100644 index 0000000..3e94f05 --- /dev/null +++ b/plugin/dnstap/io_test.go @@ -0,0 +1,155 @@ +package dnstap + +import ( + "net" + "sync" + "testing" + "time" + + "github.com/coredns/coredns/plugin/pkg/reuseport" + + tap "github.com/dnstap/golang-dnstap" + fs "github.com/farsightsec/golang-framestream" +) + +var ( + msgType = tap.Dnstap_MESSAGE + tmsg = tap.Dnstap{Type: &msgType} +) + +func accept(t *testing.T, l net.Listener, count int) { + server, err := l.Accept() + if err != nil { + t.Fatalf("Server accepted: %s", err) + } + dec, err := fs.NewDecoder(server, &fs.DecoderOptions{ + ContentType: []byte("protobuf:dnstap.Dnstap"), + Bidirectional: true, + }) + if err != nil { + t.Fatalf("Server decoder: %s", err) + } + + for i := 0; i < count; i++ { + if _, err := dec.Decode(); err != nil { + t.Errorf("Server decode: %s", err) + } + } + + if err := server.Close(); err != nil { + t.Error(err) + } +} + +func TestTransport(t *testing.T) { + transport := [2][2]string{ + {"tcp", ":0"}, + {"unix", "dnstap.sock"}, + } + + for _, param := range transport { + l, err := reuseport.Listen(param[0], param[1]) + if err != nil { + t.Fatalf("Cannot start listener: %s", err) + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + accept(t, l, 1) + wg.Done() + }() + + dio := newIO(param[0], l.Addr().String()) + dio.tcpTimeout = 10 * time.Millisecond + dio.flushTimeout = 30 * time.Millisecond + 
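// connect dials the test listener and starts the background serve goroutine that drains the queue.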
dio.connect() + + dio.Dnstap(&tmsg) + + wg.Wait() + l.Close() + dio.close() + } +} + +func TestRace(t *testing.T) { + count := 10 + + l, err := reuseport.Listen("tcp", ":0") + if err != nil { + t.Fatalf("Cannot start listener: %s", err) + } + defer l.Close() + + var wg sync.WaitGroup + wg.Add(1) + go func() { + accept(t, l, count) + wg.Done() + }() + + dio := newIO("tcp", l.Addr().String()) + dio.tcpTimeout = 10 * time.Millisecond + dio.flushTimeout = 30 * time.Millisecond + dio.connect() + defer dio.close() + + wg.Add(count) + for i := 0; i < count; i++ { + go func() { + tmsg := tap.Dnstap_MESSAGE + dio.Dnstap(&tap.Dnstap{Type: &tmsg}) + wg.Done() + }() + } + wg.Wait() +} + +func TestReconnect(t *testing.T) { + count := 5 + + l, err := reuseport.Listen("tcp", ":0") + if err != nil { + t.Fatalf("Cannot start listener: %s", err) + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + accept(t, l, 1) + wg.Done() + }() + + addr := l.Addr().String() + dio := newIO("tcp", addr) + dio.tcpTimeout = 10 * time.Millisecond + dio.flushTimeout = 30 * time.Millisecond + dio.connect() + defer dio.close() + + dio.Dnstap(&tmsg) + + wg.Wait() + + // Close listener + l.Close() + // And start TCP listener again on the same port + l, err = reuseport.Listen("tcp", addr) + if err != nil { + t.Fatalf("Cannot start listener: %s", err) + } + defer l.Close() + + wg.Add(1) + go func() { + accept(t, l, 1) + wg.Done() + }() + + for i := 0; i < count; i++ { + time.Sleep(100 * time.Millisecond) + dio.Dnstap(&tmsg) + } + wg.Wait() +} diff --git a/plugin/dnstap/log_test.go b/plugin/dnstap/log_test.go new file mode 100644 index 0000000..145aa1d --- /dev/null +++ b/plugin/dnstap/log_test.go @@ -0,0 +1,5 @@ +package dnstap + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/dnstap/msg/msg.go b/plugin/dnstap/msg/msg.go new file mode 100644 index 0000000..f9d84c4 --- /dev/null +++ b/plugin/dnstap/msg/msg.go @@ -0,0 +1,97 @@ +package msg + +import ( + "fmt" + "net" + "time" + + tap "github.com/dnstap/golang-dnstap" +) + +var ( + protoUDP = tap.SocketProtocol_UDP + protoTCP = tap.SocketProtocol_TCP + familyINET = tap.SocketFamily_INET + familyINET6 = tap.SocketFamily_INET6 +) + +// SetQueryAddress adds the query address to the message. This also sets the SocketFamily and SocketProtocol. +func SetQueryAddress(t *tap.Message, addr net.Addr) error { + t.SocketFamily = &familyINET + switch a := addr.(type) { + case *net.TCPAddr: + t.SocketProtocol = &protoTCP + t.QueryAddress = a.IP + + p := uint32(a.Port) + t.QueryPort = &p + + if a.IP.To4() == nil { + t.SocketFamily = &familyINET6 + } + return nil + case *net.UDPAddr: + t.SocketProtocol = &protoUDP + t.QueryAddress = a.IP + + p := uint32(a.Port) + t.QueryPort = &p + + if a.IP.To4() == nil { + t.SocketFamily = &familyINET6 + } + return nil + default: + return fmt.Errorf("unknown address type: %T", a) + } +} + +// SetResponseAddress the response address to the message. This also sets the SocketFamily and SocketProtocol. 
+func SetResponseAddress(t *tap.Message, addr net.Addr) error { + t.SocketFamily = &familyINET + switch a := addr.(type) { + case *net.TCPAddr: + t.SocketProtocol = &protoTCP + t.ResponseAddress = a.IP + + p := uint32(a.Port) + t.ResponsePort = &p + + if a.IP.To4() == nil { + t.SocketFamily = &familyINET6 + } + return nil + case *net.UDPAddr: + t.SocketProtocol = &protoUDP + t.ResponseAddress = a.IP + + p := uint32(a.Port) + t.ResponsePort = &p + + if a.IP.To4() == nil { + t.SocketFamily = &familyINET6 + } + return nil + default: + return fmt.Errorf("unknown address type: %T", a) + } +} + +// SetQueryTime sets the time of the query in t. +func SetQueryTime(t *tap.Message, ti time.Time) { + qts := uint64(ti.Unix()) + qtn := uint32(ti.Nanosecond()) + t.QueryTimeSec = &qts + t.QueryTimeNsec = &qtn +} + +// SetResponseTime sets the time of the response in t. +func SetResponseTime(t *tap.Message, ti time.Time) { + rts := uint64(ti.Unix()) + rtn := uint32(ti.Nanosecond()) + t.ResponseTimeSec = &rts + t.ResponseTimeNsec = &rtn +} + +// SetType sets the type in t. +func SetType(t *tap.Message, typ tap.Message_Type) { t.Type = &typ } diff --git a/plugin/dnstap/setup.go b/plugin/dnstap/setup.go new file mode 100644 index 0000000..0186f4d --- /dev/null +++ b/plugin/dnstap/setup.go @@ -0,0 +1,140 @@ +package dnstap + +import ( + "net/url" + "os" + "strings" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + clog "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/plugin/pkg/replacer" +) + +var log = clog.NewWithPlugin("dnstap") + +func init() { plugin.Register("dnstap", setup) } + +func parseConfig(c *caddy.Controller) ([]*Dnstap, error) { + dnstaps := []*Dnstap{} + + for c.Next() { // directive name + d := Dnstap{} + endpoint := "" + d.repl = replacer.New() + + args := c.RemainingArgs() + + if len(args) == 0 { + return nil, c.ArgErr() + } + + endpoint = args[0] + + var dio *dio + if strings.HasPrefix(endpoint, "tls://") { + // remote network endpoint + endpointURL, err := url.Parse(endpoint) + if err != nil { + return nil, c.ArgErr() + } + dio = newIO("tls", endpointURL.Host) + d = Dnstap{io: dio} + } else if strings.HasPrefix(endpoint, "tcp://") { + // remote network endpoint + endpointURL, err := url.Parse(endpoint) + if err != nil { + return nil, c.ArgErr() + } + dio = newIO("tcp", endpointURL.Host) + d = Dnstap{io: dio} + } else { + endpoint = strings.TrimPrefix(endpoint, "unix://") + dio = newIO("unix", endpoint) + d = Dnstap{io: dio} + } + + d.IncludeRawMessage = len(args) == 2 && args[1] == "full" + + hostname, _ := os.Hostname() + d.Identity = []byte(hostname) + d.Version = []byte(caddy.AppName + "-" + caddy.AppVersion) + + for c.NextBlock() { + switch c.Val() { + case "skipverify": + { + dio.skipVerify = true + } + case "identity": + { + if !c.NextArg() { + return nil, c.ArgErr() + } + d.Identity = []byte(c.Val()) + } + case "version": + { + if !c.NextArg() { + return nil, c.ArgErr() + } + d.Version = []byte(c.Val()) + } + case "extra": + { + if !c.NextArg() { + return nil, c.ArgErr() + } + d.ExtraFormat = c.Val() + } + } + } + dnstaps = append(dnstaps, &d) + } + return dnstaps, nil +} + +func setup(c *caddy.Controller) error { + dnstaps, err := parseConfig(c) + if err != nil { + return plugin.Error("dnstap", err) + } + + for i := range dnstaps { + dnstap := dnstaps[i] + c.OnStartup(func() error { + if err := dnstap.io.(*dio).connect(); err != nil { + log.Errorf("No connection to dnstap endpoint: %s", 
err) + } + return nil + }) + + c.OnRestart(func() error { + dnstap.io.(*dio).close() + return nil + }) + + c.OnFinalShutdown(func() error { + dnstap.io.(*dio).close() + return nil + }) + + if i == len(dnstaps)-1 { + // last dnstap plugin in block: point next to next plugin + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + dnstap.Next = next + return dnstap + }) + } else { + // not last dnstap plugin in block: point next to next dnstap + nextDnstap := dnstaps[i+1] + dnsserver.GetConfig(c).AddPlugin(func(plugin.Handler) plugin.Handler { + dnstap.Next = nextDnstap + return dnstap + }) + } + } + + return nil +} diff --git a/plugin/dnstap/setup_test.go b/plugin/dnstap/setup_test.go new file mode 100644 index 0000000..8365963 --- /dev/null +++ b/plugin/dnstap/setup_test.go @@ -0,0 +1,137 @@ +package dnstap + +import ( + "os" + "reflect" + "testing" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" +) + +type results struct { + endpoint string + full bool + proto string + identity []byte + version []byte + extraFormat string +} + +func TestConfig(t *testing.T) { + hostname, _ := os.Hostname() + tests := []struct { + in string + fail bool + expect []results + }{ + {"dnstap dnstap.sock full", false, []results{{"dnstap.sock", true, "unix", []byte(hostname), []byte("-"), ""}}}, + {"dnstap unix://dnstap.sock", false, []results{{"dnstap.sock", false, "unix", []byte(hostname), []byte("-"), ""}}}, + {"dnstap tcp://127.0.0.1:6000", false, []results{{"127.0.0.1:6000", false, "tcp", []byte(hostname), []byte("-"), ""}}}, + {"dnstap tcp://[::1]:6000", false, []results{{"[::1]:6000", false, "tcp", []byte(hostname), []byte("-"), ""}}}, + {"dnstap tcp://example.com:6000", false, []results{{"example.com:6000", false, "tcp", []byte(hostname), []byte("-"), ""}}}, + {"dnstap", true, []results{{"fail", false, "tcp", []byte(hostname), []byte("-"), ""}}}, + {"dnstap dnstap.sock full {\nidentity NAME\nversion VER\n}\n", false, []results{{"dnstap.sock", true, "unix", []byte("NAME"), []byte("VER"), ""}}}, + {"dnstap dnstap.sock full {\nidentity NAME\nversion VER\nextra EXTRA\n}\n", false, []results{{"dnstap.sock", true, "unix", []byte("NAME"), []byte("VER"), "EXTRA"}}}, + {"dnstap dnstap.sock {\nidentity NAME\nversion VER\nextra EXTRA\n}\n", false, []results{{"dnstap.sock", false, "unix", []byte("NAME"), []byte("VER"), "EXTRA"}}}, + {"dnstap {\nidentity NAME\nversion VER\nextra EXTRA\n}\n", true, []results{{"fail", false, "tcp", []byte("NAME"), []byte("VER"), "EXTRA"}}}, + {`dnstap dnstap.sock full { + identity NAME + version VER + extra EXTRA + } + dnstap tcp://127.0.0.1:6000 { + identity NAME2 + version VER2 + extra EXTRA2 + }`, false, []results{ + {"dnstap.sock", true, "unix", []byte("NAME"), []byte("VER"), "EXTRA"}, + {"127.0.0.1:6000", false, "tcp", []byte("NAME2"), []byte("VER2"), "EXTRA2"}, + }}, + {"dnstap tls://127.0.0.1:6000", false, []results{{"127.0.0.1:6000", false, "tls", []byte(hostname), []byte("-"), ""}}}, + {"dnstap dnstap.sock {\nidentity\n}\n", true, []results{{"dnstap.sock", false, "unix", []byte(hostname), []byte("-"), ""}}}, + {"dnstap dnstap.sock {\nversion\n}\n", true, []results{{"dnstap.sock", false, "unix", []byte(hostname), []byte("-"), ""}}}, + {"dnstap dnstap.sock {\nextra\n}\n", true, []results{{"dnstap.sock", false, "unix", []byte(hostname), []byte("-"), ""}}}, + } + for i, tc := range tests { + c := caddy.NewTestController("dns", tc.in) + taps, err := parseConfig(c) + if tc.fail && err == nil { + t.Fatalf("Test %d: expected test to 
fail: %s: %s", i, tc.in, err) + } + if tc.fail { + continue + } + + if err != nil { + t.Fatalf("Test %d: expected no error, got %s", i, err) + } + for i, tap := range taps { + if x := tap.io.(*dio).endpoint; x != tc.expect[i].endpoint { + t.Errorf("Test %d: expected endpoint %s, got %s", i, tc.expect[i].endpoint, x) + } + if x := tap.io.(*dio).proto; x != tc.expect[i].proto { + t.Errorf("Test %d: expected proto %s, got %s", i, tc.expect[i].proto, x) + } + if x := tap.IncludeRawMessage; x != tc.expect[i].full { + t.Errorf("Test %d: expected IncludeRawMessage %t, got %t", i, tc.expect[i].full, x) + } + if x := string(tap.Identity); x != string(tc.expect[i].identity) { + t.Errorf("Test %d: expected identity %s, got %s", i, tc.expect[i].identity, x) + } + if x := string(tap.Version); x != string(tc.expect[i].version) { + t.Errorf("Test %d: expected version %s, got %s", i, tc.expect[i].version, x) + } + if x := tap.ExtraFormat; x != tc.expect[i].extraFormat { + t.Errorf("Test %d: expected extra format %s, got %s", i, tc.expect[i].extraFormat, x) + } + } + } +} + +func TestMultiDnstap(t *testing.T) { + input := ` + dnstap dnstap1.sock + dnstap dnstap2.sock + dnstap dnstap3.sock + ` + + c := caddy.NewTestController("dns", input) + setup(c) + dnsserver.NewServer("", []*dnsserver.Config{dnsserver.GetConfig(c)}) + + handlers := dnsserver.GetConfig(c).Handlers() + d1, ok := handlers[0].(*Dnstap) + if !ok { + t.Fatalf("expected first plugin to be Dnstap, got %v", reflect.TypeOf(d1.Next)) + } + + if d1.io.(*dio).endpoint != "dnstap1.sock" { + t.Errorf("expected first dnstap to \"dnstap1.sock\", got %q", d1.io.(*dio).endpoint) + } + if d1.Next == nil { + t.Fatal("expected first dnstap to point to next dnstap instance") + } + + d2, ok := d1.Next.(*Dnstap) + if !ok { + t.Fatalf("expected second plugin to be Dnstap, got %v", reflect.TypeOf(d1.Next)) + } + if d2.io.(*dio).endpoint != "dnstap2.sock" { + t.Errorf("expected second dnstap to \"dnstap2.sock\", got %q", d2.io.(*dio).endpoint) + } + if d2.Next == nil { + t.Fatal("expected second dnstap to point to third dnstap instance") + } + + d3, ok := d2.Next.(*Dnstap) + if !ok { + t.Fatalf("expected third plugin to be Dnstap, got %v", reflect.TypeOf(d2.Next)) + } + if d3.io.(*dio).endpoint != "dnstap3.sock" { + t.Errorf("expected third dnstap to \"dnstap3.sock\", got %q", d3.io.(*dio).endpoint) + } + if d3.Next != nil { + t.Error("expected third plugin to be last, but Next is not nil") + } +} diff --git a/plugin/dnstap/writer.go b/plugin/dnstap/writer.go new file mode 100644 index 0000000..afd19ea --- /dev/null +++ b/plugin/dnstap/writer.go @@ -0,0 +1,44 @@ +package dnstap + +import ( + "context" + "time" + + "github.com/coredns/coredns/plugin/dnstap/msg" + "github.com/coredns/coredns/request" + + tap "github.com/dnstap/golang-dnstap" + "github.com/miekg/dns" +) + +// ResponseWriter captures the client response and logs the query to dnstap. +type ResponseWriter struct { + queryTime time.Time + query *dns.Msg + ctx context.Context + dns.ResponseWriter + Dnstap +} + +// WriteMsg writes back the response to the client and THEN works on logging the request and response to dnstap. 
+func (w *ResponseWriter) WriteMsg(resp *dns.Msg) error { + err := w.ResponseWriter.WriteMsg(resp) + if err != nil { + return err + } + + r := new(tap.Message) + msg.SetQueryTime(r, w.queryTime) + msg.SetResponseTime(r, time.Now()) + msg.SetQueryAddress(r, w.RemoteAddr()) + + if w.IncludeRawMessage { + buf, _ := resp.Pack() + r.ResponseMessage = buf + } + + msg.SetType(r, tap.Message_CLIENT_RESPONSE) + state := request.Request{W: w.ResponseWriter, Req: w.query} + w.TapMessageWithMetadata(w.ctx, r, state) + return nil +} diff --git a/plugin/done.go b/plugin/done.go new file mode 100644 index 0000000..c6ff863 --- /dev/null +++ b/plugin/done.go @@ -0,0 +1,13 @@ +package plugin + +import "context" + +// Done is a non-blocking function that returns true if the context has been canceled. +func Done(ctx context.Context) bool { + select { + case <-ctx.Done(): + return true + default: + return false + } +} diff --git a/plugin/erratic/README.md b/plugin/erratic/README.md new file mode 100644 index 0000000..5e2b06b --- /dev/null +++ b/plugin/erratic/README.md @@ -0,0 +1,89 @@ +# erratic + +## Name + +*erratic* - a plugin useful for testing client behavior. + +## Description + +*erratic* returns a static response to all queries, but the responses can be delayed, +dropped or truncated. The *erratic* plugin will respond to every A or AAAA query. For +any other type it will return a SERVFAIL response (except AXFR). The reply for A will return +192.0.2.53 ([RFC 5737](https://tools.ietf.org/html/rfc5737)), for AAAA it returns 2001:DB8::53 ([RFC +3849](https://tools.ietf.org/html/rfc3849)). For an AXFR request it will respond with a small +zone transfer. + +## Syntax + +~~~ txt +erratic { + drop [AMOUNT] + truncate [AMOUNT] + delay [AMOUNT [DURATION]] +} +~~~ + +* `drop`: drop 1 per **AMOUNT** of queries, the default is 2. +* `truncate`: truncate 1 per **AMOUNT** of queries, the default is 2. +* `delay`: delay 1 per **AMOUNT** of queries for **DURATION**, the default for **AMOUNT** is 2 and + the default for **DURATION** is 100ms. + +In case of a zone transfer and truncate the final SOA record *isn't* added to the response. + +## Ready + +This plugin reports readiness to the ready plugin. + +## Examples + +~~~ corefile +example.org { + erratic { + drop 3 + } +} +~~~ + +Or even shorter if the defaults suit you. Note this only drops queries, it does not delay them. + +~~~ corefile +example.org { + erratic +} +~~~ + +Delay 1 in 3 queries for 50ms + +~~~ corefile +example.org { + erratic { + delay 3 50ms + } +} +~~~ + +Delay 1 in 3 and truncate 1 in 5. + +~~~ corefile +example.org { + erratic { + delay 3 5ms + truncate 5 + } +} +~~~ + +Drop every second query. + +~~~ corefile +example.org { + erratic { + drop 2 + truncate 2 + } +} +~~~ + +## See Also + +[RFC 3849](https://tools.ietf.org/html/rfc3849) and [RFC 5737](https://tools.ietf.org/html/rfc5737). diff --git a/plugin/erratic/autopath.go b/plugin/erratic/autopath.go new file mode 100644 index 0000000..0e29fff --- /dev/null +++ b/plugin/erratic/autopath.go @@ -0,0 +1,8 @@ +package erratic + +import "github.com/coredns/coredns/request" + +// AutoPath implements the AutoPathFunc call from the autopath plugin. 
+func (e *Erratic) AutoPath(state request.Request) []string { + return []string{"a.example.org.", "b.example.org.", ""} +} diff --git a/plugin/erratic/erratic.go b/plugin/erratic/erratic.go new file mode 100644 index 0000000..da7f68a --- /dev/null +++ b/plugin/erratic/erratic.go @@ -0,0 +1,109 @@ +// Package erratic implements a plugin that returns erratic answers (delayed, dropped). +package erratic + +import ( + "context" + "sync/atomic" + "time" + + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// Erratic is a plugin that returns erratic responses to each client. +type Erratic struct { + q uint64 // counter of queries + drop uint64 + delay uint64 + truncate uint64 + + duration time.Duration + large bool // undocumented feature; return large responses for A request (>512B, to test compression). +} + +// ServeDNS implements the plugin.Handler interface. +func (e *Erratic) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + state := request.Request{W: w, Req: r} + drop := false + delay := false + trunc := false + + queryNr := atomic.LoadUint64(&e.q) + atomic.AddUint64(&e.q, 1) + + if e.drop > 0 && queryNr%e.drop == 0 { + drop = true + } + if e.delay > 0 && queryNr%e.delay == 0 { + delay = true + } + if e.truncate > 0 && queryNr&e.truncate == 0 { + trunc = true + } + + m := new(dns.Msg) + m.SetReply(r) + m.Authoritative = true + if trunc { + m.Truncated = true + } + + // small dance to copy rrA or rrAAAA into a non-pointer var that allows us to overwrite the ownername + // in a non-racy way. + switch state.QType() { + case dns.TypeA: + rr := *(rrA.(*dns.A)) + rr.Header().Name = state.QName() + m.Answer = append(m.Answer, &rr) + if e.large { + for i := 0; i < 29; i++ { + m.Answer = append(m.Answer, &rr) + } + } + case dns.TypeAAAA: + rr := *(rrAAAA.(*dns.AAAA)) + rr.Header().Name = state.QName() + m.Answer = append(m.Answer, &rr) + case dns.TypeAXFR: + if drop { + return 0, nil + } + if delay { + time.Sleep(e.duration) + } + + xfr(state, trunc) + return 0, nil + + default: + if drop { + return 0, nil + } + if delay { + time.Sleep(e.duration) + } + // coredns will return error. + return dns.RcodeServerFailure, nil + } + + if drop { + return 0, nil + } + + if delay { + time.Sleep(e.duration) + } + + w.WriteMsg(m) + + return 0, nil +} + +// Name implements the Handler interface. +func (e *Erratic) Name() string { return "erratic" } + +var ( + rrA, _ = dns.NewRR(". IN 0 A 192.0.2.53") + rrAAAA, _ = dns.NewRR(". 
IN 0 AAAA 2001:DB8::53") +) diff --git a/plugin/erratic/erratic_test.go b/plugin/erratic/erratic_test.go new file mode 100644 index 0000000..de8dbe4 --- /dev/null +++ b/plugin/erratic/erratic_test.go @@ -0,0 +1,116 @@ +package erratic + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestErraticDrop(t *testing.T) { + e := &Erratic{drop: 2} // 50% drops + + tests := []struct { + rrtype uint16 + expectedCode int + expectedErr error + drop bool + }{ + {rrtype: dns.TypeA, expectedCode: dns.RcodeSuccess, expectedErr: nil, drop: true}, + {rrtype: dns.TypeA, expectedCode: dns.RcodeSuccess, expectedErr: nil, drop: false}, + {rrtype: dns.TypeAAAA, expectedCode: dns.RcodeSuccess, expectedErr: nil, drop: true}, + {rrtype: dns.TypeHINFO, expectedCode: dns.RcodeServerFailure, expectedErr: nil, drop: false}, + } + + ctx := context.TODO() + + for i, tc := range tests { + req := new(dns.Msg) + req.SetQuestion("example.org.", tc.rrtype) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + code, err := e.ServeDNS(ctx, rec, req) + + if err != tc.expectedErr { + t.Errorf("Test %d: Expected error %q, but got %q", i, tc.expectedErr, err) + } + if code != tc.expectedCode { + t.Errorf("Test %d: Expected status code %d, but got %d", i, tc.expectedCode, code) + } + + if tc.drop && rec.Msg != nil { + t.Errorf("Test %d: Expected dropped message, but got %q", i, rec.Msg.Question[0].Name) + } + } +} + +func TestErraticTruncate(t *testing.T) { + e := &Erratic{truncate: 2} // 50% drops + + tests := []struct { + expectedCode int + expectedErr error + truncate bool + }{ + {expectedCode: dns.RcodeSuccess, expectedErr: nil, truncate: true}, + {expectedCode: dns.RcodeSuccess, expectedErr: nil, truncate: false}, + } + + ctx := context.TODO() + + for i, tc := range tests { + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + code, err := e.ServeDNS(ctx, rec, req) + + if err != tc.expectedErr { + t.Errorf("Test %d: Expected error %q, but got %q", i, tc.expectedErr, err) + } + if code != tc.expectedCode { + t.Errorf("Test %d: Expected status code %d, but got %d", i, tc.expectedCode, code) + } + + if tc.truncate && !rec.Msg.Truncated { + t.Errorf("Test %d: Expected truncated message, but got %q", i, rec.Msg.Question[0].Name) + } + } +} + +func TestAxfr(t *testing.T) { + e := &Erratic{truncate: 0} // nothing, just check if we can get an axfr + + ctx := context.TODO() + + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeAXFR) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err := e.ServeDNS(ctx, rec, req) + if err != nil { + t.Errorf("Failed to set up AXFR: %s", err) + } + if x := rec.Msg.Answer[0].Header().Rrtype; x != dns.TypeSOA { + t.Errorf("Expected for record to be %d, got %d", dns.TypeSOA, x) + } +} + +func TestErratic(t *testing.T) { + e := &Erratic{drop: 0, delay: 0} + + ctx := context.TODO() + + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + e.ServeDNS(ctx, rec, req) + + if rec.Msg.Answer[0].Header().Rrtype != dns.TypeA { + t.Errorf("Expected A response, got %d type", rec.Msg.Answer[0].Header().Rrtype) + } +} diff --git a/plugin/erratic/log_test.go b/plugin/erratic/log_test.go new file mode 100644 index 0000000..f6fb4bf --- /dev/null +++ b/plugin/erratic/log_test.go @@ -0,0 +1,5 @@ +package erratic + +import clog 
"github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/erratic/ready.go b/plugin/erratic/ready.go new file mode 100644 index 0000000..d5f18a6 --- /dev/null +++ b/plugin/erratic/ready.go @@ -0,0 +1,13 @@ +package erratic + +import "sync/atomic" + +// Ready returns true if the number of received queries is in the range [3, 5). All other values return false. +// To aid in testing we want to this flip between ready and not ready. +func (e *Erratic) Ready() bool { + q := atomic.LoadUint64(&e.q) + if q >= 3 && q < 5 { + return true + } + return false +} diff --git a/plugin/erratic/setup.go b/plugin/erratic/setup.go new file mode 100644 index 0000000..524473c --- /dev/null +++ b/plugin/erratic/setup.go @@ -0,0 +1,113 @@ +package erratic + +import ( + "fmt" + "strconv" + "time" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" +) + +func init() { plugin.Register("erratic", setup) } + +func setup(c *caddy.Controller) error { + e, err := parseErratic(c) + if err != nil { + return plugin.Error("erratic", err) + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + return e + }) + + return nil +} + +func parseErratic(c *caddy.Controller) (*Erratic, error) { + e := &Erratic{drop: 2} + drop := false // true if we've seen the drop keyword + + for c.Next() { // 'erratic' + for c.NextBlock() { + switch c.Val() { + case "drop": + args := c.RemainingArgs() + if len(args) > 1 { + return nil, c.ArgErr() + } + + if len(args) == 0 { + continue + } + + amount, err := strconv.ParseInt(args[0], 10, 32) + if err != nil { + return nil, err + } + if amount < 0 { + return nil, fmt.Errorf("illegal amount value given %q", args[0]) + } + e.drop = uint64(amount) + drop = true + case "delay": + args := c.RemainingArgs() + if len(args) > 2 { + return nil, c.ArgErr() + } + + // Defaults. 
+ e.delay = 2 + e.duration = 100 * time.Millisecond + if len(args) == 0 { + continue + } + + amount, err := strconv.ParseInt(args[0], 10, 32) + if err != nil { + return nil, err + } + if amount < 0 { + return nil, fmt.Errorf("illegal amount value given %q", args[0]) + } + e.delay = uint64(amount) + + if len(args) > 1 { + duration, err := time.ParseDuration(args[1]) + if err != nil { + return nil, err + } + e.duration = duration + } + case "truncate": + args := c.RemainingArgs() + if len(args) > 1 { + return nil, c.ArgErr() + } + + if len(args) == 0 { + continue + } + + amount, err := strconv.ParseInt(args[0], 10, 32) + if err != nil { + return nil, err + } + if amount < 0 { + return nil, fmt.Errorf("illegal amount value given %q", args[0]) + } + e.truncate = uint64(amount) + case "large": + e.large = true + default: + return nil, c.Errf("unknown property '%s'", c.Val()) + } + } + } + if (e.delay > 0 || e.truncate > 0) && !drop { // delay is set, but we've haven't seen a drop keyword, remove default drop stuff + e.drop = 0 + } + + return e, nil +} diff --git a/plugin/erratic/setup_test.go b/plugin/erratic/setup_test.go new file mode 100644 index 0000000..9d2ff51 --- /dev/null +++ b/plugin/erratic/setup_test.go @@ -0,0 +1,103 @@ +package erratic + +import ( + "testing" + + "github.com/coredns/caddy" +) + +func TestSetup(t *testing.T) { + c := caddy.NewTestController("dns", `erratic { + drop + }`) + if err := setup(c); err != nil { + t.Fatalf("Test 1, expected no errors, but got: %q", err) + } + + c = caddy.NewTestController("dns", `erratic`) + if err := setup(c); err != nil { + t.Fatalf("Test 2, expected no errors, but got: %q", err) + } + + c = caddy.NewTestController("dns", `erratic { + drop -1 + }`) + if err := setup(c); err == nil { + t.Fatalf("Test 4, expected errors, but got: %q", err) + } +} + +func TestParseErratic(t *testing.T) { + tests := []struct { + input string + shouldErr bool + drop uint64 + delay uint64 + truncate uint64 + }{ + // oks + {`erratic`, false, 2, 0, 0}, + {`erratic { + drop 2 + delay 3 1ms + + }`, false, 2, 3, 0}, + {`erratic { + truncate 2 + delay 3 1ms + + }`, false, 0, 3, 2}, + {`erraric { + drop 3 + delay + }`, false, 3, 2, 0}, + // fails + {`erratic { + drop -1 + }`, true, 0, 0, 0}, + {`erratic { + delay -1 + }`, true, 0, 0, 0}, + {`erratic { + delay 1 2 4 + }`, true, 0, 0, 0}, + {`erratic { + delay 15.a + }`, true, 0, 0, 0}, + {`erraric { + drop 3 + delay 3 bla + }`, true, 0, 0, 0}, + {`erraric { + truncate 15.a + }`, true, 0, 0, 0}, + {`erraric { + something-else + }`, true, 0, 0, 0}, + } + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + e, err := parseErratic(c) + if test.shouldErr && err == nil { + t.Errorf("Test %v: Expected error but found nil", i) + continue + } else if !test.shouldErr && err != nil { + t.Errorf("Test %v: Expected no error but found error: %v", i, err) + continue + } + + if test.shouldErr { + continue + } + + if test.delay != e.delay { + t.Errorf("Test %v: Expected delay %d but found: %d", i, test.delay, e.delay) + } + if test.drop != e.drop { + t.Errorf("Test %v: Expected drop %d but found: %d", i, test.drop, e.drop) + } + if test.truncate != e.truncate { + t.Errorf("Test %v: Expected truncate %d but found: %d", i, test.truncate, e.truncate) + } + } +} diff --git a/plugin/erratic/xfr.go b/plugin/erratic/xfr.go new file mode 100644 index 0000000..e1ec77e --- /dev/null +++ b/plugin/erratic/xfr.go @@ -0,0 +1,57 @@ +package erratic + +import ( + "strings" + "sync" + + 
"github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// allRecords returns a small zone file. The first RR must be a SOA. +func allRecords(name string) []dns.RR { + var rrs = []dns.RR{ + test.SOA("xx. 0 IN SOA sns.dns.icann.org. noc.dns.icann.org. 2018050825 7200 3600 1209600 3600"), + test.NS("xx. 0 IN NS b.xx."), + test.NS("xx. 0 IN NS a.xx."), + test.AAAA("a.xx. 0 IN AAAA 2001:bd8::53"), + test.AAAA("b.xx. 0 IN AAAA 2001:500::54"), + } + + for _, r := range rrs { + r.Header().Name = strings.Replace(r.Header().Name, "xx.", name, 1) + + if n, ok := r.(*dns.NS); ok { + n.Ns = strings.Replace(n.Ns, "xx.", name, 1) + } + } + return rrs +} + +func xfr(state request.Request, truncate bool) { + rrs := allRecords(state.QName()) + + ch := make(chan *dns.Envelope) + tr := new(dns.Transfer) + + go func() { + // So the rrs we have don't have a closing SOA, only add that when truncate is false, + // so we send an incomplete AXFR. + if !truncate { + rrs = append(rrs, rrs[0]) + } + + ch <- &dns.Envelope{RR: rrs} + close(ch) + }() + + wg := new(sync.WaitGroup) + wg.Add(1) + go func() { + tr.Out(state.W, state.Req, ch) + wg.Done() + }() + wg.Wait() +} diff --git a/plugin/errors/README.md b/plugin/errors/README.md new file mode 100644 index 0000000..27ba105 --- /dev/null +++ b/plugin/errors/README.md @@ -0,0 +1,65 @@ +# errors + +## Name + +*errors* - enables error logging. + +## Description + +Any errors encountered during the query processing will be printed to standard output. The errors of particular type can be consolidated and printed once per some period of time. + +This plugin can only be used once per Server Block. + +## Syntax + +The basic syntax is: + +~~~ +errors +~~~ + +Extra knobs are available with an expanded syntax: + +~~~ +errors { + stacktrace + consolidate DURATION REGEXP [LEVEL] +} +~~~ + +Option `stacktrace` will log a stacktrace during panic recovery. + +Option `consolidate` allows collecting several error messages matching the regular expression **REGEXP** during **DURATION**. After the **DURATION** since receiving the first such message, the consolidated message will be printed to standard output with +log level, which is configurable by optional option **LEVEL**. Supported options for **LEVEL** option are `warning`,`error`,`info` and `debug`. +~~~ +2 errors like '^read udp .* i/o timeout$' occurred in last 30s +~~~ + +Multiple `consolidate` options with different **DURATION** and **REGEXP** are allowed. In case if some error message corresponds to several defined regular expressions the message will be associated with the first appropriate **REGEXP**. + +For better performance, it's recommended to use the `^` or `$` metacharacters in regular expression when filtering error messages by prefix or suffix, e.g. `^failed to .*`, or `.* timeout$`. + +## Examples + +Use the *whoami* to respond to queries in the example.org domain and Log errors to standard output. + +~~~ corefile +example.org { + whoami + errors +} +~~~ + +Use the *forward* plugin to resolve queries via 8.8.8.8 and print consolidated messages +for errors with suffix " i/o timeout" as warnings, +and errors with prefix "Failed to " as errors. + +~~~ corefile +. { + forward . 
8.8.8.8 + errors { + consolidate 5m ".* i/o timeout$" warning + consolidate 30s "^Failed to .+" + } +} +~~~ diff --git a/plugin/errors/benchmark_test.go b/plugin/errors/benchmark_test.go new file mode 100644 index 0000000..04e6433 --- /dev/null +++ b/plugin/errors/benchmark_test.go @@ -0,0 +1,27 @@ +package errors + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func BenchmarkServeDNS(b *testing.B) { + h := &errorHandler{} + h.Next = test.ErrorHandler() + + r := new(dns.Msg) + r.SetQuestion("example.org.", dns.TypeA) + w := &test.ResponseWriter{} + ctx := context.TODO() + + for i := 0; i < b.N; i++ { + _, err := h.ServeDNS(ctx, w, r) + if err != nil { + b.Errorf("ServeDNS returned error: %s", err) + } + } +} diff --git a/plugin/errors/errors.go b/plugin/errors/errors.go new file mode 100644 index 0000000..c045f69 --- /dev/null +++ b/plugin/errors/errors.go @@ -0,0 +1,104 @@ +// Package errors implements an error handling plugin. +package errors + +import ( + "context" + "regexp" + "sync/atomic" + "time" + "unsafe" + + "github.com/coredns/coredns/plugin" + clog "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +var log = clog.NewWithPlugin("errors") + +type pattern struct { + ptimer unsafe.Pointer + count uint32 + period time.Duration + pattern *regexp.Regexp + logCallback func(format string, v ...interface{}) +} + +func (p *pattern) timer() *time.Timer { + return (*time.Timer)(atomic.LoadPointer(&p.ptimer)) +} + +func (p *pattern) setTimer(t *time.Timer) { + atomic.StorePointer(&p.ptimer, unsafe.Pointer(t)) +} + +// errorHandler handles DNS errors (and errors from other plugin). +type errorHandler struct { + patterns []*pattern + stopFlag uint32 + Next plugin.Handler +} + +func newErrorHandler() *errorHandler { + return &errorHandler{} +} + +func (h *errorHandler) logPattern(i int) { + cnt := atomic.SwapUint32(&h.patterns[i].count, 0) + if cnt > 0 { + h.patterns[i].logCallback("%d errors like '%s' occurred in last %s", + cnt, h.patterns[i].pattern.String(), h.patterns[i].period) + } +} + +func (h *errorHandler) inc(i int) bool { + if atomic.LoadUint32(&h.stopFlag) > 0 { + return false + } + if atomic.AddUint32(&h.patterns[i].count, 1) == 1 { + ind := i + t := time.AfterFunc(h.patterns[ind].period, func() { + h.logPattern(ind) + }) + h.patterns[ind].setTimer(t) + if atomic.LoadUint32(&h.stopFlag) > 0 && t.Stop() { + h.logPattern(ind) + } + } + return true +} + +func (h *errorHandler) stop() { + atomic.StoreUint32(&h.stopFlag, 1) + for i := range h.patterns { + t := h.patterns[i].timer() + if t != nil && t.Stop() { + h.logPattern(i) + } + } +} + +// ServeDNS implements the plugin.Handler interface. +func (h *errorHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + rcode, err := plugin.NextOrFailure(h.Name(), h.Next, ctx, w, r) + + if err != nil { + strErr := err.Error() + for i := range h.patterns { + if h.patterns[i].pattern.MatchString(strErr) { + if h.inc(i) { + return rcode, err + } + break + } + } + state := request.Request{W: w, Req: r} + log.Errorf("%d %s %s: %s", rcode, state.Name(), state.Type(), strErr) + } + + return rcode, err +} + +// Name implements the plugin.Handler interface. 
+func (h *errorHandler) Name() string { return "errors" } diff --git a/plugin/errors/errors_test.go b/plugin/errors/errors_test.go new file mode 100644 index 0000000..1cd42b4 --- /dev/null +++ b/plugin/errors/errors_test.go @@ -0,0 +1,237 @@ +package errors + +import ( + "bytes" + "context" + "errors" + "fmt" + golog "log" + "regexp" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/dnstest" + clog "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestErrors(t *testing.T) { + buf := bytes.Buffer{} + golog.SetOutput(&buf) + em := errorHandler{} + + testErr := errors.New("test error") + tests := []struct { + next plugin.Handler + expectedCode int + expectedLog string + expectedErr error + }{ + { + next: genErrorHandler(dns.RcodeSuccess, nil), + expectedCode: dns.RcodeSuccess, + expectedLog: "", + expectedErr: nil, + }, + { + next: genErrorHandler(dns.RcodeNotAuth, testErr), + expectedCode: dns.RcodeNotAuth, + expectedLog: fmt.Sprintf("%d %s: %v\n", dns.RcodeNotAuth, "example.org. A", testErr), + expectedErr: testErr, + }, + } + + ctx := context.TODO() + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + + for i, tc := range tests { + em.Next = tc.next + buf.Reset() + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + code, err := em.ServeDNS(ctx, rec, req) + + if err != tc.expectedErr { + t.Errorf("Test %d: Expected error %v, but got %v", + i, tc.expectedErr, err) + } + if code != tc.expectedCode { + t.Errorf("Test %d: Expected status code %d, but got %d", + i, tc.expectedCode, code) + } + if log := buf.String(); !strings.Contains(log, tc.expectedLog) { + t.Errorf("Test %d: Expected log %q, but got %q", + i, tc.expectedLog, log) + } + } +} + +func TestLogPattern(t *testing.T) { + type args struct { + logCallback func(format string, v ...interface{}) + } + tests := []struct { + name string + args args + want string + }{ + { + name: "error log", + args: args{logCallback: log.Errorf}, + want: "[ERROR] plugin/errors: 4 errors like '^error.*!$' occurred in last 2s", + }, + { + name: "warn log", + args: args{logCallback: log.Warningf}, + want: "[WARNING] plugin/errors: 4 errors like '^error.*!$' occurred in last 2s", + }, + { + name: "info log", + args: args{logCallback: log.Infof}, + want: "[INFO] plugin/errors: 4 errors like '^error.*!$' occurred in last 2s", + }, + { + name: "debug log", + args: args{logCallback: log.Debugf}, + want: "[DEBUG] plugin/errors: 4 errors like '^error.*!$' occurred in last 2s", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := bytes.Buffer{} + clog.D.Set() + golog.SetOutput(&buf) + + h := &errorHandler{ + patterns: []*pattern{{ + count: 4, + period: 2 * time.Second, + pattern: regexp.MustCompile("^error.*!$"), + logCallback: tt.args.logCallback, + }}, + } + h.logPattern(0) + + if log := buf.String(); !strings.Contains(log, tt.want) { + t.Errorf("Expected log %q, but got %q", tt.want, log) + } + }) + } +} + +func TestInc(t *testing.T) { + h := &errorHandler{ + stopFlag: 1, + patterns: []*pattern{{ + period: 2 * time.Second, + pattern: regexp.MustCompile("^error.*!$"), + }}, + } + + ret := h.inc(0) + if ret { + t.Error("Unexpected return value, expected false, actual true") + } + + h.stopFlag = 0 + ret = h.inc(0) + if !ret { + t.Error("Unexpected return value, expected true, actual false") + } + + expCnt := uint32(1) + actCnt := 
atomic.LoadUint32(&h.patterns[0].count) + if actCnt != expCnt { + t.Errorf("Unexpected 'count', expected %d, actual %d", expCnt, actCnt) + } + + t1 := h.patterns[0].timer() + if t1 == nil { + t.Error("Unexpected 'timer', expected not nil") + } + + ret = h.inc(0) + if !ret { + t.Error("Unexpected return value, expected true, actual false") + } + + expCnt = uint32(2) + actCnt = atomic.LoadUint32(&h.patterns[0].count) + if actCnt != expCnt { + t.Errorf("Unexpected 'count', expected %d, actual %d", expCnt, actCnt) + } + + t2 := h.patterns[0].timer() + if t2 != t1 { + t.Error("Unexpected 'timer', expected the same") + } + + ret = t1.Stop() + if !ret { + t.Error("Timer was unexpectedly stopped before") + } + ret = t2.Stop() + if ret { + t.Error("Timer was unexpectedly not stopped before") + } +} + +func TestStop(t *testing.T) { + buf := bytes.Buffer{} + golog.SetOutput(&buf) + + h := &errorHandler{ + patterns: []*pattern{{ + period: 2 * time.Second, + pattern: regexp.MustCompile("^error.*!$"), + logCallback: log.Errorf, + }}, + } + + h.inc(0) + h.inc(0) + h.inc(0) + expCnt := uint32(3) + actCnt := atomic.LoadUint32(&h.patterns[0].count) + if actCnt != expCnt { + t.Fatalf("Unexpected initial 'count', expected %d, actual %d", expCnt, actCnt) + } + + h.stop() + + expCnt = uint32(0) + actCnt = atomic.LoadUint32(&h.patterns[0].count) + if actCnt != expCnt { + t.Errorf("Unexpected 'count', expected %d, actual %d", expCnt, actCnt) + } + + expStop := uint32(1) + actStop := h.stopFlag + if actStop != expStop { + t.Errorf("Unexpected 'stop', expected %d, actual %d", expStop, actStop) + } + + t1 := h.patterns[0].timer() + if t1 == nil { + t.Error("Unexpected 'timer', expected not nil") + } else if t1.Stop() { + t.Error("Timer was unexpectedly not stopped before") + } + + expLog := "3 errors like '^error.*!$' occurred in last 2s" + if log := buf.String(); !strings.Contains(log, expLog) { + t.Errorf("Expected log %q, but got %q", expLog, log) + } +} + +func genErrorHandler(rcode int, err error) plugin.Handler { + return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + return rcode, err + }) +} diff --git a/plugin/errors/log_test.go b/plugin/errors/log_test.go new file mode 100644 index 0000000..643c16a --- /dev/null +++ b/plugin/errors/log_test.go @@ -0,0 +1,5 @@ +package errors + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/errors/setup.go b/plugin/errors/setup.go new file mode 100644 index 0000000..c040e10 --- /dev/null +++ b/plugin/errors/setup.go @@ -0,0 +1,109 @@ +package errors + +import ( + "regexp" + "time" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" +) + +func init() { plugin.Register("errors", setup) } + +func setup(c *caddy.Controller) error { + handler, err := errorsParse(c) + if err != nil { + return plugin.Error("errors", err) + } + + c.OnShutdown(func() error { + handler.stop() + return nil + }) + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + handler.Next = next + return handler + }) + + return nil +} + +func errorsParse(c *caddy.Controller) (*errorHandler, error) { + handler := newErrorHandler() + + i := 0 + for c.Next() { + if i > 0 { + return nil, plugin.ErrOnce + } + i++ + + args := c.RemainingArgs() + switch len(args) { + case 0: + case 1: + if args[0] != "stdout" { + return nil, c.Errf("invalid log file: %s", args[0]) + } + default: + return nil, c.ArgErr() + } + + for 
c.NextBlock() { + switch c.Val() { + case "stacktrace": + dnsserver.GetConfig(c).Stacktrace = true + case "consolidate": + pattern, err := parseConsolidate(c) + if err != nil { + return nil, err + } + handler.patterns = append(handler.patterns, pattern) + default: + return handler, c.SyntaxErr("Unknown field " + c.Val()) + } + } + } + return handler, nil +} + +func parseConsolidate(c *caddy.Controller) (*pattern, error) { + args := c.RemainingArgs() + if len(args) < 2 || len(args) > 3 { + return nil, c.ArgErr() + } + p, err := time.ParseDuration(args[0]) + if err != nil { + return nil, c.Err(err.Error()) + } + re, err := regexp.Compile(args[1]) + if err != nil { + return nil, c.Err(err.Error()) + } + lc, err := parseLogLevel(c, args) + if err != nil { + return nil, err + } + return &pattern{period: p, pattern: re, logCallback: lc}, nil +} + +func parseLogLevel(c *caddy.Controller, args []string) (func(format string, v ...interface{}), error) { + if len(args) != 3 { + return log.Errorf, nil + } + + switch args[2] { + case "warning": + return log.Warningf, nil + case "error": + return log.Errorf, nil + case "info": + return log.Infof, nil + case "debug": + return log.Debugf, nil + default: + return nil, c.Errf("unknown log level argument in consolidate: %s", args[2]) + } +} diff --git a/plugin/errors/setup_test.go b/plugin/errors/setup_test.go new file mode 100644 index 0000000..5dbc9ec --- /dev/null +++ b/plugin/errors/setup_test.go @@ -0,0 +1,148 @@ +package errors + +import ( + "bytes" + golog "log" + "strings" + "testing" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + clog "github.com/coredns/coredns/plugin/pkg/log" +) + +func TestErrorsParse(t *testing.T) { + tests := []struct { + inputErrorsRules string + shouldErr bool + optCount int + stacktrace bool + }{ + {`errors`, false, 0, false}, + {`errors stdout`, false, 0, false}, + {`errors errors.txt`, true, 0, false}, + {`errors visible`, true, 0, false}, + {`errors { log visible }`, true, 0, false}, + {`errors + errors `, true, 0, false}, + {`errors a b`, true, 0, false}, + + {`errors { + consolidate + }`, true, 0, false}, + {`errors { + consolidate 1m + }`, true, 0, false}, + {`errors { + consolidate 1m .* extra + }`, true, 0, false}, + {`errors { + consolidate abc .* + }`, true, 0, false}, + {`errors { + consolidate 1 .* + }`, true, 0, false}, + {`errors { + consolidate 1m ()) + }`, true, 0, false}, + {`errors { + stacktrace + }`, false, 0, true}, + {`errors { + stacktrace + consolidate 1m ^exact$ + }`, false, 1, true}, + {`errors { + consolidate 1m ^exact$ + }`, false, 1, false}, + {`errors { + consolidate 1m error + }`, false, 1, false}, + {`errors { + consolidate 1m "format error" + }`, false, 1, false}, + {`errors { + consolidate 1m error1 + consolidate 5s error2 + }`, false, 2, false}, + } + for i, test := range tests { + c := caddy.NewTestController("dns", test.inputErrorsRules) + h, err := errorsParse(c) + + if err == nil && test.shouldErr { + t.Errorf("Test %d didn't error, but it should have", i) + } else if err != nil && !test.shouldErr { + t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err) + } else if h != nil && len(h.patterns) != test.optCount { + t.Errorf("Test %d: pattern count mismatch, expected %d, got %d", + i, test.optCount, len(h.patterns)) + } + if dnsserver.GetConfig(c).Stacktrace != test.stacktrace { + t.Errorf("Test %d: stacktrace, expected %t, got %t", + i, test.stacktrace, dnsserver.GetConfig(c).Stacktrace) + } + } +} + +func TestProperLogCallbackIsSet(t 
*testing.T) { + tests := []struct { + name string + inputErrorsRules string + wantLogLevel string + }{ + { + name: "warning is parsed properly", + inputErrorsRules: `errors { + consolidate 1m .* warning + }`, + wantLogLevel: "[WARNING]", + }, + { + name: "error is parsed properly", + inputErrorsRules: `errors { + consolidate 1m .* error + }`, + wantLogLevel: "[ERROR]", + }, + { + name: "info is parsed properly", + inputErrorsRules: `errors { + consolidate 1m .* info + }`, + wantLogLevel: "[INFO]", + }, + { + name: "debug is parsed properly", + inputErrorsRules: `errors { + consolidate 1m .* debug + }`, + wantLogLevel: "[DEBUG]", + }, + { + name: "default is error", + inputErrorsRules: `errors { + consolidate 1m .* + }`, + wantLogLevel: "[ERROR]", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := bytes.Buffer{} + golog.SetOutput(&buf) + clog.D.Set() + + c := caddy.NewTestController("dns", tt.inputErrorsRules) + h, _ := errorsParse(c) + + l := h.patterns[0].logCallback + l("some error happened") + + if log := buf.String(); !strings.Contains(log, tt.wantLogLevel) { + t.Errorf("Expected log %q, but got %q", tt.wantLogLevel, log) + } + }) + } +} diff --git a/plugin/etcd/README.md b/plugin/etcd/README.md new file mode 100644 index 0000000..0c7f1ea --- /dev/null +++ b/plugin/etcd/README.md @@ -0,0 +1,236 @@ +# etcd + +## Name + +*etcd* - enables SkyDNS service discovery from etcd. + +## Description + +The *etcd* plugin implements the (older) SkyDNS service discovery service. It is *not* suitable as +a generic DNS zone data plugin. Only a subset of DNS record types are implemented, and subdomains +and delegations are not handled at all. The plugin will also recursively descend the tree and return +all records found, see "Special Behavior" below for details. + +The data in the etcd instance has to be encoded as +a [message](https://github.com/skynetservices/skydns/blob/2fcff74cdc9f9a7dd64189a447ef27ac354b725f/msg/service.go#L26) +like [SkyDNS](https://github.com/skynetservices/skydns). It works just like SkyDNS. + +The *etcd* plugin makes extensive use of the *forward* plugin to forward and query other servers in the +network - if that plugin has been enabled as well. + +## Syntax + +~~~ +etcd [ZONES...] +~~~ + +* **ZONES** zones *etcd* should be authoritative for. + +The path will default to `/skydns` the local etcd3 proxy (http://localhost:2379). If no zones are +specified the block's zone will be used as the zone. + + +~~~ +etcd [ZONES...] { + fallthrough [ZONES...] + path PATH + endpoint ENDPOINT... + credentials USERNAME PASSWORD + tls CERT KEY CACERT +} +~~~ + +* `fallthrough` If zone matches but no record can be generated, pass request to the next plugin. + If **[ZONES...]** is omitted, then fallthrough happens for all zones for which the plugin + is authoritative. If specific zones are listed (for example `in-addr.arpa` and `ip6.arpa`), then only + queries for those zones will be subject to fallthrough. +* **PATH** the path inside etcd. Defaults to "/skydns". +* **ENDPOINT** the etcd endpoints. Defaults to "http://localhost:2379". +* `credentials` is used to set the **USERNAME** and **PASSWORD** for accessing the etcd cluster. 
+* `tls` followed by: + + * no arguments, if the server certificate is signed by a system-installed CA and no client cert is needed + * a single argument that is the CA PEM file, if the server cert is not signed by a system CA and no client cert is needed + * two arguments - path to cert PEM file, the path to private key PEM file - if the server certificate is signed by a system-installed CA and a client certificate is needed + * three arguments - path to cert PEM file, path to client private key PEM file, path to CA PEM + file - if the server certificate is not signed by a system-installed CA and client certificate + is needed. + +## Special Behaviour + +The *etcd* plugin leverages directory structure to look for related entries. For example +an entry `/skydns/test/skydns/mx` would have entries like `/skydns/test/skydns/mx/a`, +`/skydns/test/skydns/mx/b` and so on. Similarly a directory `/skydns/test/skydns/mx1` will have all +`mx1` entries. Note this plugin will search through the entire (sub)tree for records. In case of the +first example, a query for `mx.skydns.test` will return both the contents of the `a` and `b` records. +If the directory extends deeper those records are returned as well. + +With etcd3, support for [hierarchical keys are +dropped](https://coreos.com/etcd/docs/latest/learning/api.html). This means there are no directories +but only flat keys with prefixes in etcd3. To accommodate lookups, the *etcd* plugin now does a lookup +on prefix `/skydns/test/skydns/mx/` to search for entries like `/skydns/test/skydns/mx/a` etc, and +if there is nothing found on `/skydns/test/skydns/mx/`, it looks for `/skydns/test/skydns/mx` to +find entries like `/skydns/test/skydns/mx1`. + +This causes two lookups from CoreDNS to etcd in certain cases. + +## Examples + +This is the default SkyDNS setup, with everything specified in full: + +~~~ corefile +skydns.local { + etcd { + path /skydns + endpoint http://localhost:2379 + } + prometheus + cache + loadbalance +} + +. { + forward . 8.8.8.8:53 8.8.4.4:53 + cache +} +~~~ + +Or a setup where we use `/etc/resolv.conf` as the basis for the proxy and the upstream +when resolving external pointing CNAMEs. + +~~~ corefile +skydns.local { + etcd { + path /skydns + } + cache +} + +. { + forward . /etc/resolv.conf + cache +} +~~~ + +Multiple endpoints are supported as well. + +~~~ +etcd skydns.local { + endpoint http://localhost:2379 http://localhost:4001 +... +~~~ +Before getting started with these examples, please setup `etcdctl` (with `etcdv3` API) as explained +[here](https://coreos.com/etcd/docs/latest/dev-guide/interacting_v3.html). This will help you to put +sample keys in your etcd server. + +If you prefer, you can use `curl` to populate the `etcd` server, but with `curl` the +endpoint URL depends on the version of `etcd`. For instance, `etcd v3.2` or before uses only +[CLIENT-URL]/v3alpha/* while `etcd v3.5` or later uses [CLIENT-URL]/v3/* . Also, Key and Value must +be base64 encoded in the JSON payload. With `etcdctl` these details are automatically taken care +of. You can check [this document](https://github.com/coreos/etcd/blob/master/Documentation/dev-guide/api_grpc_gateway.md#notes) +for details. + +### Reverse zones + +Reverse zones are supported. You need to make CoreDNS aware of the fact that you are also +authoritative for the reverse. For instance if you want to add the reverse for 10.0.0.0/24, you'll +need to add the zone `0.0.10.in-addr.arpa` to the list of zones. 
Showing a snippet of a Corefile: + +~~~ +etcd skydns.local 10.0.0.0/24 { +... +~~~ + +Next you'll need to populate the zone with reverse records, here we add a reverse for +10.0.0.127 pointing to reverse.skydns.local. + +~~~ +% etcdctl put /skydns/arpa/in-addr/10/0/0/127 '{"host":"reverse.skydns.local."}' +~~~ + +Querying with dig: + +~~~ sh +% dig @localhost -x 10.0.0.127 +short +reverse.skydns.local. +~~~ + +### Zone name as A record + +The zone name itself can be used as an `A` record. This behavior can be achieved by writing special +entries to the ETCD path of your zone. If your zone is named `skydns.local` for example, you can +create an `A` record for this zone as follows: + +~~~ +% etcdctl put /skydns/local/skydns/ '{"host":"1.1.1.1","ttl":60}' +~~~ + +If you query the zone name itself, you will receive the created `A` record: + +~~~ sh +% dig +short skydns.local @localhost +1.1.1.1 +~~~ + +If you would like to use DNS RR for the zone name, you can set the following: +~~~ +% etcdctl put /skydns/local/skydns/x1 '{"host":"1.1.1.1","ttl":60}' +% etcdctl put /skydns/local/skydns/x2 '{"host":"1.1.1.2","ttl":60}' +~~~ + +If you query the zone name now, you will get the following response: + +~~~ sh +% dig +short skydns.local @localhost +1.1.1.1 +1.1.1.2 +~~~ + +### Zone name as AAAA record + +If you would like to use `AAAA` records for the zone name too, you can set the following: +~~~ +% etcdctl put /skydns/local/skydns/x3 '{"host":"2003::8:1","ttl":60}' +% etcdctl put /skydns/local/skydns/x4 '{"host":"2003::8:2","ttl":60}' +~~~ + +If you query the zone name for `AAAA` now, you will get the following response: +~~~ sh +% dig +short skydns.local AAAA @localhost +2003::8:1 +2003::8:2 +~~~ + +### SRV record + +If you would like to use `SRV` records, you can set the following: +~~~ +% etcdctl put /skydns/local/skydns/x5 '{"host":"skydns-local.server","ttl":60,"priority":10,"port":8080}' +~~~ +Please notice that the key `host` is the `target` in `SRV`, so it should be a domain name. + +If you query the zone name for `SRV` now, you will get the following response: + +~~~ sh +% dig +short skydns.local SRV @localhost +10 100 8080 skydns-local.server. +~~~ + +### TXT record + +If you would like to use `TXT` records, you can set the following: +~~~ +% etcdctl put /skydns/local/skydns/x6 '{"ttl":60,"text":"this is a random text message."}' +% etcdctl put /skydns/local/skydns/x7 '{"ttl":60,"text":"this is a another random text message."}' +~~~ + +If you query the zone name for `TXT` now, you will get the following response: +~~~ sh +% dig +short skydns.local TXT @localhost +"this is a random text message." +"this is a another random text message." +~~~ + +## See Also + +If you want to `round robin` A and AAAA responses look at the *loadbalance* plugin. diff --git a/plugin/etcd/cname_test.go b/plugin/etcd/cname_test.go new file mode 100644 index 0000000..1e64d6d --- /dev/null +++ b/plugin/etcd/cname_test.go @@ -0,0 +1,108 @@ +//go:build etcd + +package etcd + +// etcd needs to be running on http://localhost:2379 + +import ( + "testing" + + "github.com/coredns/coredns/plugin/etcd/msg" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +// Check the ordering of returned cname. 
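// Illustrative sketch (not part of this commit): the services seeded below form a CNAME
// chain because each entry's Host is the DNS name of the next entry's Key and only the
// final entry holds an IP address. followChain is a hypothetical, simplified resolver
// that walks such a chain the way the test below expects the plugin to, returning the
// hops in order.
package main

import (
	"fmt"
	"net"
)

// followChain walks host pointers until it reaches an IP address, or gives up after
// maxHops (mirroring the loop protection a real resolver needs).
func followChain(records map[string]string, name string, maxHops int) []string {
	chain := []string{}
	for i := 0; i < maxHops; i++ {
		target, ok := records[name]
		if !ok {
			break
		}
		chain = append(chain, name+" -> "+target)
		if net.ParseIP(target) != nil {
			break // reached the terminating A/AAAA record
		}
		name = target
	}
	return chain
}

func main() {
	records := map[string]string{
		"cname1.region2.skydns.test":   "cname2.region2.skydns.test",
		"cname2.region2.skydns.test":   "endpoint.region2.skydns.test",
		"endpoint.region2.skydns.test": "10.240.0.1",
	}
	for _, hop := range followChain(records, "cname1.region2.skydns.test", 8) {
		fmt.Println(hop)
	}
}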
+func TestCnameLookup(t *testing.T) { + etc := newEtcdPlugin() + + for _, serv := range servicesCname { + set(t, etc, serv.Key, 0, serv) + defer delete(t, etc, serv.Key) + } + for i, tc := range dnsTestCasesCname { + m := tc.Msg() + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err := etc.ServeDNS(ctxt, rec, m) + if err != nil { + t.Errorf("Expected no error, got %v", err) + return + } + + resp := rec.Msg + if err := test.Header(tc, resp); err != nil { + t.Errorf("Test %d: %v", i, err) + continue + } + if err := test.Section(tc, test.Answer, resp.Answer); err != nil { + t.Errorf("Test %d: %v", i, err) + } + if err := test.Section(tc, test.Ns, resp.Ns); err != nil { + t.Errorf("Test %d: %v", i, err) + } + if err := test.Section(tc, test.Extra, resp.Extra); err != nil { + t.Errorf("Test %d: %v", i, err) + } + } +} + +var servicesCname = []*msg.Service{ + {Host: "cname1.region2.skydns.test", Key: "a.server1.dev.region1.skydns.test."}, + {Host: "cname2.region2.skydns.test", Key: "cname1.region2.skydns.test."}, + {Host: "cname3.region2.skydns.test", Key: "cname2.region2.skydns.test."}, + {Host: "cname4.region2.skydns.test", Key: "cname3.region2.skydns.test."}, + {Host: "cname5.region2.skydns.test", Key: "cname4.region2.skydns.test."}, + {Host: "cname6.region2.skydns.test", Key: "cname5.region2.skydns.test."}, + {Host: "endpoint.region2.skydns.test", Key: "cname6.region2.skydns.test."}, + {Host: "10.240.0.1", Key: "endpoint.region2.skydns.test."}, + + {Host: "mainendpoint.region2.skydns.test", Key: "region2.skydns.test."}, + + {Host: "cname2.region3.skydns.test", Key: "cname3.region3.skydns.test."}, + {Host: "cname1.region3.skydns.test", Key: "cname2.region3.skydns.test."}, + {Host: "endpoint.region3.skydns.test", Key: "cname1.region3.skydns.test."}, + {Host: "", Key: "endpoint.region3.skydns.test.", Text: "SOME-RECORD-TEXT"}, +} + +var dnsTestCasesCname = []test.Case{ + { // Test 0 + Qname: "a.server1.dev.region1.skydns.test.", Qtype: dns.TypeSRV, + Answer: []dns.RR{ + test.SRV("a.server1.dev.region1.skydns.test. 300 IN SRV 10 100 0 cname1.region2.skydns.test."), + }, + Extra: []dns.RR{ + test.CNAME("cname1.region2.skydns.test. 300 IN CNAME cname2.region2.skydns.test."), + test.CNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."), + test.CNAME("cname3.region2.skydns.test. 300 IN CNAME cname4.region2.skydns.test."), + test.CNAME("cname4.region2.skydns.test. 300 IN CNAME cname5.region2.skydns.test."), + test.CNAME("cname5.region2.skydns.test. 300 IN CNAME cname6.region2.skydns.test."), + test.CNAME("cname6.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."), + test.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), + }, + }, + { // Test 1 + Qname: "region2.skydns.test.", Qtype: dns.TypeCNAME, + Answer: []dns.RR{ + test.CNAME("region2.skydns.test. 300 IN CNAME mainendpoint.region2.skydns.test."), + }, + }, + { // Test 2 + Qname: "endpoint.region3.skydns.test.", Qtype: dns.TypeCNAME, + Rcode: dns.RcodeSuccess, + Ns: []dns.RR{ + test.SOA("skydns.test. 303 IN SOA ns.dns.skydns.test. hostmaster.skydns.test. 1546424605 7200 1800 86400 30"), + }, + }, + { // Test 3 + Qname: "cname3.region3.skydns.test.", Qtype: dns.TypeTXT, + Answer: []dns.RR{ + test.CNAME("cname3.region3.skydns.test. 300 IN CNAME cname2.region3.skydns.test."), + test.CNAME("cname2.region3.skydns.test. 300 IN CNAME cname1.region3.skydns.test."), + test.CNAME("cname1.region3.skydns.test. 300 IN CNAME endpoint.region3.skydns.test."), + test.TXT("endpoint.region3.skydns.test. 
300 IN TXT \"SOME-RECORD-TEXT\""), + }, + }, +} diff --git a/plugin/etcd/etcd.go b/plugin/etcd/etcd.go new file mode 100644 index 0000000..077e490 --- /dev/null +++ b/plugin/etcd/etcd.go @@ -0,0 +1,185 @@ +// Package etcd provides the etcd version 3 backend plugin. +package etcd + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/etcd/msg" + "github.com/coredns/coredns/plugin/pkg/fall" + "github.com/coredns/coredns/plugin/pkg/upstream" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" + "go.etcd.io/etcd/api/v3/mvccpb" + etcdcv3 "go.etcd.io/etcd/client/v3" +) + +const ( + priority = 10 // default priority when nothing is set + ttl = 300 // default ttl when nothing is set + etcdTimeout = 5 * time.Second +) + +var errKeyNotFound = errors.New("key not found") + +// Etcd is a plugin talks to an etcd cluster. +type Etcd struct { + Next plugin.Handler + Fall fall.F + Zones []string + PathPrefix string + Upstream *upstream.Upstream + Client *etcdcv3.Client + + endpoints []string // Stored here as well, to aid in testing. +} + +// Services implements the ServiceBackend interface. +func (e *Etcd) Services(ctx context.Context, state request.Request, exact bool, opt plugin.Options) (services []msg.Service, err error) { + services, err = e.Records(ctx, state, exact) + if err != nil { + return + } + + services = msg.Group(services) + return +} + +// Reverse implements the ServiceBackend interface. +func (e *Etcd) Reverse(ctx context.Context, state request.Request, exact bool, opt plugin.Options) (services []msg.Service, err error) { + return e.Services(ctx, state, exact, opt) +} + +// Lookup implements the ServiceBackend interface. +func (e *Etcd) Lookup(ctx context.Context, state request.Request, name string, typ uint16) (*dns.Msg, error) { + return e.Upstream.Lookup(ctx, state, name, typ) +} + +// IsNameError implements the ServiceBackend interface. +func (e *Etcd) IsNameError(err error) bool { + return err == errKeyNotFound +} + +// Records looks up records in etcd. If exact is true, it will lookup just this +// name. This is used when find matches when completing SRV lookups for instance. 
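// Illustrative sketch (not part of this commit): the prefix-then-exact lookup described
// in the README's "Special Behaviour" section and implemented by e.get below. For a
// query that maps to /skydns/test/skydns/mx, the plugin first asks etcd for everything
// under the prefix "/skydns/test/skydns/mx/" and, only if that returns nothing, retries
// the bare key. The endpoint and key names here are assumptions for the example.
package main

import (
	"context"
	"fmt"
	"time"

	etcdcv3 "go.etcd.io/etcd/client/v3"
)

func main() {
	cli, err := etcdcv3.New(etcdcv3.Config{
		Endpoints:   []string{"http://localhost:2379"},
		DialTimeout: 5 * time.Second,
	})
	if err != nil {
		panic(err)
	}
	defer cli.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	path := "/skydns/test/skydns/mx"

	// First lookup: recursive, everything below the directory-like prefix.
	r, err := cli.Get(ctx, path+"/", etcdcv3.WithPrefix())
	if err != nil {
		panic(err)
	}
	if r.Count == 0 {
		// Nothing under the prefix: fall back to the bare key itself.
		r, err = cli.Get(ctx, path)
		if err != nil {
			panic(err)
		}
	}
	fmt.Printf("found %d key(s)\n", r.Count)
}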
+func (e *Etcd) Records(ctx context.Context, state request.Request, exact bool) ([]msg.Service, error) { + name := state.Name() + + path, star := msg.PathWithWildcard(name, e.PathPrefix) + r, err := e.get(ctx, path, !exact) + if err != nil { + return nil, err + } + segments := strings.Split(msg.Path(name, e.PathPrefix), "/") + return e.loopNodes(r.Kvs, segments, star, state.QType()) +} + +func (e *Etcd) get(ctx context.Context, path string, recursive bool) (*etcdcv3.GetResponse, error) { + ctx, cancel := context.WithTimeout(ctx, etcdTimeout) + defer cancel() + if recursive { + if !strings.HasSuffix(path, "/") { + path = path + "/" + } + r, err := e.Client.Get(ctx, path, etcdcv3.WithPrefix()) + if err != nil { + return nil, err + } + if r.Count == 0 { + path = strings.TrimSuffix(path, "/") + r, err = e.Client.Get(ctx, path) + if err != nil { + return nil, err + } + if r.Count == 0 { + return nil, errKeyNotFound + } + } + return r, nil + } + + r, err := e.Client.Get(ctx, path) + if err != nil { + return nil, err + } + if r.Count == 0 { + return nil, errKeyNotFound + } + return r, nil +} + +func (e *Etcd) loopNodes(kv []*mvccpb.KeyValue, nameParts []string, star bool, qType uint16) (sx []msg.Service, err error) { + bx := make(map[msg.Service]struct{}) +Nodes: + for _, n := range kv { + if star { + s := string(n.Key) + keyParts := strings.Split(s, "/") + for i, n := range nameParts { + if i > len(keyParts)-1 { + // name is longer than key + continue Nodes + } + if n == "*" || n == "any" { + continue + } + if keyParts[i] != n { + continue Nodes + } + } + } + serv := new(msg.Service) + if err := json.Unmarshal(n.Value, serv); err != nil { + return nil, fmt.Errorf("%s: %s", n.Key, err.Error()) + } + serv.Key = string(n.Key) + if _, ok := bx[*serv]; ok { + continue + } + bx[*serv] = struct{}{} + + serv.TTL = e.TTL(n, serv) + if serv.Priority == 0 { + serv.Priority = priority + } + + if shouldInclude(serv, qType) { + sx = append(sx, *serv) + } + } + return sx, nil +} + +// TTL returns the smaller of the etcd TTL and the service's +// TTL. If neither of these are set (have a zero value), a default is used. +func (e *Etcd) TTL(kv *mvccpb.KeyValue, serv *msg.Service) uint32 { + etcdTTL := uint32(kv.Lease) + + if etcdTTL == 0 && serv.TTL == 0 { + return ttl + } + if etcdTTL == 0 { + return serv.TTL + } + if serv.TTL == 0 { + return etcdTTL + } + if etcdTTL < serv.TTL { + return etcdTTL + } + return serv.TTL +} + +// shouldInclude returns true if the service should be included in a list of records, given the qType. For all the +// currently supported lookup types, the only one to allow for an empty Host field in the service are TXT records +// which resolve directly. If a TXT record is being resolved by CNAME, then we expect the Host field to have a +// value while the TXT field will be empty. 
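// Illustrative example (not part of this commit): a hypothetical snippet, assumed to
// live alongside this file in the etcd package, spelling out the rule described in the
// comment above and implemented by shouldInclude (defined next): a service with an
// empty Host is only usable for TXT queries, and only when it carries Text; every
// other query type needs a Host.
package etcd

import (
	"fmt"

	"github.com/coredns/coredns/plugin/etcd/msg"

	"github.com/miekg/dns"
)

// demoShouldInclude is a hypothetical helper, for illustration only.
func demoShouldInclude() {
	txtOnly := &msg.Service{Text: "SOME-RECORD-TEXT", Key: "endpoint.region3.skydns.test."}
	hostOnly := &msg.Service{Host: "10.240.0.1", Key: "endpoint.region2.skydns.test."}

	fmt.Println(shouldInclude(txtOnly, dns.TypeTXT)) // true: TXT query and Text is set
	fmt.Println(shouldInclude(txtOnly, dns.TypeA))   // false: nothing to resolve without a Host
	fmt.Println(shouldInclude(hostOnly, dns.TypeA))  // true: Host is present
}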
+func shouldInclude(serv *msg.Service, qType uint16) bool { + return (qType == dns.TypeTXT && serv.Text != "") || serv.Host != "" +} diff --git a/plugin/etcd/group_test.go b/plugin/etcd/group_test.go new file mode 100644 index 0000000..2620bf2 --- /dev/null +++ b/plugin/etcd/group_test.go @@ -0,0 +1,87 @@ +//go:build etcd + +package etcd + +import ( + "testing" + + "github.com/coredns/coredns/plugin/etcd/msg" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestGroupLookup(t *testing.T) { + etc := newEtcdPlugin() + + for _, serv := range servicesGroup { + set(t, etc, serv.Key, 0, serv) + defer delete(t, etc, serv.Key) + } + for _, tc := range dnsTestCasesGroup { + m := tc.Msg() + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err := etc.ServeDNS(ctxt, rec, m) + if err != nil { + t.Errorf("Expected no error, got %v", err) + continue + } + + resp := rec.Msg + if err := test.SortAndCheck(resp, tc); err != nil { + t.Error(err) + } + } +} + +// Note the key is encoded as DNS name, while in "reality" it is a etcd path. +var servicesGroup = []*msg.Service{ + {Host: "127.0.0.1", Key: "a.dom.skydns.test.", Group: "g1"}, + {Host: "127.0.0.2", Key: "b.sub.dom.skydns.test.", Group: "g1"}, + + {Host: "127.0.0.1", Key: "a.dom2.skydns.test.", Group: "g1"}, + {Host: "127.0.0.2", Key: "b.sub.dom2.skydns.test.", Group: ""}, + + {Host: "127.0.0.1", Key: "a.dom1.skydns.test.", Group: "g1"}, + {Host: "127.0.0.2", Key: "b.sub.dom1.skydns.test.", Group: "g2"}, + + {Text: "foo", Key: "a.dom3.skydns.test.", Group: "g1"}, + {Text: "bar", Key: "b.sub.dom3.skydns.test.", Group: "g1"}, +} + +var dnsTestCasesGroup = []test.Case{ + // Groups + { + // hits the group 'g1' and only includes those A records + Qname: "dom.skydns.test.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.A("dom.skydns.test. 300 IN A 127.0.0.1"), + test.A("dom.skydns.test. 300 IN A 127.0.0.2"), + }, + }, + { + // One has group, the other has not... Include the non-group always. + Qname: "dom2.skydns.test.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.A("dom2.skydns.test. 300 IN A 127.0.0.1"), + test.A("dom2.skydns.test. 300 IN A 127.0.0.2"), + }, + }, + { + // The groups differ. + Qname: "dom1.skydns.test.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.A("dom1.skydns.test. 300 IN A 127.0.0.1"), + }, + }, + { + // hits the group 'g1' and only includes those TXT records + Qname: "dom3.skydns.test.", Qtype: dns.TypeTXT, + Answer: []dns.RR{ + test.TXT("dom3.skydns.test. 300 IN TXT bar"), + test.TXT("dom3.skydns.test. 300 IN TXT foo"), + }, + }, +} diff --git a/plugin/etcd/handler.go b/plugin/etcd/handler.go new file mode 100644 index 0000000..5a99753 --- /dev/null +++ b/plugin/etcd/handler.go @@ -0,0 +1,82 @@ +package etcd + +import ( + "context" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// ServeDNS implements the plugin.Handler interface. 
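// Illustrative sketch (not part of this commit): before ServeDNS (below) dispatches on
// the query type, it checks whether the query name falls inside one of the plugin's
// zones via plugin.Zones(...).Matches; a miss means the request is handed to the next
// plugin in the chain. The zone list here mirrors the test setup elsewhere in this
// change.
package main

import (
	"fmt"

	"github.com/coredns/coredns/plugin"
)

func main() {
	zones := plugin.Zones([]string{"skydns.test.", "in-addr.arpa."})

	for _, qname := range []string{
		"a.server1.prod.region1.skydns.test.", // inside skydns.test. -> handled here
		"1.0.0.10.in-addr.arpa.",              // reverse zone -> handled here
		"example.org.",                        // no match -> passed to the next plugin
	} {
		if zone := zones.Matches(qname); zone != "" {
			fmt.Printf("%s handled as part of zone %s\n", qname, zone)
		} else {
			fmt.Printf("%s not ours, falling through to the next plugin\n", qname)
		}
	}
}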
+func (e *Etcd) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + opt := plugin.Options{} + state := request.Request{W: w, Req: r} + + zone := plugin.Zones(e.Zones).Matches(state.Name()) + if zone == "" { + return plugin.NextOrFailure(e.Name(), e.Next, ctx, w, r) + } + + var ( + records, extra []dns.RR + truncated bool + err error + ) + + switch state.QType() { + case dns.TypeA: + records, truncated, err = plugin.A(ctx, e, zone, state, nil, opt) + case dns.TypeAAAA: + records, truncated, err = plugin.AAAA(ctx, e, zone, state, nil, opt) + case dns.TypeTXT: + records, truncated, err = plugin.TXT(ctx, e, zone, state, nil, opt) + case dns.TypeCNAME: + records, err = plugin.CNAME(ctx, e, zone, state, opt) + case dns.TypePTR: + records, err = plugin.PTR(ctx, e, zone, state, opt) + case dns.TypeMX: + records, extra, err = plugin.MX(ctx, e, zone, state, opt) + case dns.TypeSRV: + records, extra, err = plugin.SRV(ctx, e, zone, state, opt) + case dns.TypeSOA: + records, err = plugin.SOA(ctx, e, zone, state, opt) + case dns.TypeNS: + if state.Name() == zone { + records, extra, err = plugin.NS(ctx, e, zone, state, opt) + break + } + fallthrough + default: + // Do a fake A lookup, so we can distinguish between NODATA and NXDOMAIN + _, _, err = plugin.A(ctx, e, zone, state, nil, opt) + } + if err != nil && e.IsNameError(err) { + if e.Fall.Through(state.Name()) { + return plugin.NextOrFailure(e.Name(), e.Next, ctx, w, r) + } + // Make err nil when returning here, so we don't log spam for NXDOMAIN. + return plugin.BackendError(ctx, e, zone, dns.RcodeNameError, state, nil /* err */, opt) + } + if err != nil { + return plugin.BackendError(ctx, e, zone, dns.RcodeServerFailure, state, err, opt) + } + + if len(records) == 0 { + return plugin.BackendError(ctx, e, zone, dns.RcodeSuccess, state, err, opt) + } + + m := new(dns.Msg) + m.SetReply(r) + m.Truncated = truncated + m.Authoritative = true + m.Answer = append(m.Answer, records...) + m.Extra = append(m.Extra, extra...) + + w.WriteMsg(m) + return dns.RcodeSuccess, nil +} + +// Name implements the Handler interface. +func (e *Etcd) Name() string { return "etcd" } diff --git a/plugin/etcd/log_test.go b/plugin/etcd/log_test.go new file mode 100644 index 0000000..57735be --- /dev/null +++ b/plugin/etcd/log_test.go @@ -0,0 +1,5 @@ +package etcd + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/etcd/lookup_test.go b/plugin/etcd/lookup_test.go new file mode 100644 index 0000000..0b689b0 --- /dev/null +++ b/plugin/etcd/lookup_test.go @@ -0,0 +1,355 @@ +//go:build etcd + +package etcd + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/coredns/coredns/plugin/etcd/msg" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/pkg/tls" + "github.com/coredns/coredns/plugin/pkg/upstream" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func init() { + ctxt = context.TODO() +} + +// Note the key is encoded as DNS name, while in "reality" it is a etcd path. +var services = []*msg.Service{ + {Host: "dev.server1", Port: 8080, Key: "a.server1.dev.region1.skydns.test."}, + {Host: "10.0.0.1", Port: 8080, Key: "a.server1.prod.region1.skydns.test."}, + {Host: "10.0.0.2", Port: 8080, Key: "b.server1.prod.region1.skydns.test."}, + {Host: "::1", Port: 8080, Key: "b.server6.prod.region1.skydns.test."}, + // TXT record in server1. 
+ {Text: "sometext", Key: "a.txt.server1.prod.region1.skydns.test."}, + {Text: "moretext", Key: "b.txt.server1.prod.region1.skydns.test."}, + // Unresolvable internal name. + {Host: "unresolvable.skydns.test", Key: "cname.prod.region1.skydns.test."}, + // Priority. + {Host: "priority.server1", Priority: 333, Port: 8080, Key: "priority.skydns.test."}, + // Subdomain. + {Host: "sub.server1", Port: 0, Key: "a.sub.region1.skydns.test."}, + {Host: "sub.server2", Port: 80, Key: "b.sub.region1.skydns.test."}, + {Host: "10.0.0.1", Port: 8080, Key: "c.sub.region1.skydns.test."}, + // TargetStrip. + {Host: "10.0.0.1", Port: 8080, Key: "a.targetstrip.skydns.test.", TargetStrip: 1}, + // Cname loop. + {Host: "a.cname.skydns.test", Key: "b.cname.skydns.test."}, + {Host: "b.cname.skydns.test", Key: "a.cname.skydns.test."}, + // Nameservers. + {Host: "10.0.0.2", Key: "a.ns.dns.skydns.test."}, + {Host: "10.0.0.3", Key: "b.ns.dns.skydns.test."}, + {Host: "10.0.0.4", Key: "ns1.c.ns.dns.skydns.test.", TargetStrip: 1}, + {Host: "10.0.0.5", Key: "ns2.c.ns.dns.skydns.test.", TargetStrip: 1}, + // Zone name as A record (basic, return all) + {Host: "10.0.0.2", Key: "x.skydns_zonea.test."}, + {Host: "10.0.0.3", Key: "y.skydns_zonea.test."}, + // Zone name as A (single entry). + {Host: "10.0.0.2", Key: "x.skydns_zoneb.test."}, + {Host: "10.0.0.3", Key: "y.skydns_zoneb.test."}, + {Host: "10.0.0.4", Key: "apex.dns.skydns_zoneb.test."}, + // A zone record (rr multiple entries). + {Host: "10.0.0.2", Key: "x.skydns_zonec.test."}, + {Host: "10.0.0.3", Key: "y.skydns_zonec.test."}, + {Host: "10.0.0.4", Key: "a1.apex.dns.skydns_zonec.test."}, + {Host: "10.0.0.5", Key: "a2.apex.dns.skydns_zonec.test."}, + // AAAA zone record (rr multiple entries mixed with A). + {Host: "10.0.0.2", Key: "x.skydns_zoned.test."}, + {Host: "10.0.0.3", Key: "y.skydns_zoned.test."}, + {Host: "10.0.0.4", Key: "a1.apex.dns.skydns_zoned.test."}, + {Host: "10.0.0.5", Key: "a2.apex.dns.skydns_zoned.test."}, + {Host: "2003::8:1", Key: "a3.apex.dns.skydns_zoned.test."}, + {Host: "2003::8:2", Key: "a4.apex.dns.skydns_zoned.test."}, + // Reverse. + {Host: "reverse.example.com", Key: "1.0.0.10.in-addr.arpa."}, // 10.0.0.1 +} + +var dnsTestCases = []test.Case{ + // SRV Test + { + Qname: "a.server1.dev.region1.skydns.test.", Qtype: dns.TypeSRV, + Answer: []dns.RR{test.SRV("a.server1.dev.region1.skydns.test. 300 SRV 10 100 8080 dev.server1.")}, + }, + // SRV Test (case test) + { + Qname: "a.SERVer1.dEv.region1.skydns.tEst.", Qtype: dns.TypeSRV, + Answer: []dns.RR{test.SRV("a.SERVer1.dEv.region1.skydns.tEst. 300 SRV 10 100 8080 dev.server1.")}, + }, + // NXDOMAIN Test + { + Qname: "doesnotexist.skydns.test.", Qtype: dns.TypeA, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("skydns.test. 30 SOA ns.dns.skydns.test. hostmaster.skydns.test. 0 0 0 0 0"), + }, + }, + // A Test + { + Qname: "a.server1.prod.region1.skydns.test.", Qtype: dns.TypeA, + Answer: []dns.RR{test.A("a.server1.prod.region1.skydns.test. 300 A 10.0.0.1")}, + }, + // SRV Test where target is IP address + { + Qname: "a.server1.prod.region1.skydns.test.", Qtype: dns.TypeSRV, + Answer: []dns.RR{test.SRV("a.server1.prod.region1.skydns.test. 300 SRV 10 100 8080 a.server1.prod.region1.skydns.test.")}, + Extra: []dns.RR{test.A("a.server1.prod.region1.skydns.test. 300 A 10.0.0.1")}, + }, + // AAAA Test + { + Qname: "b.server6.prod.region1.skydns.test.", Qtype: dns.TypeAAAA, + Answer: []dns.RR{test.AAAA("b.server6.prod.region1.skydns.test. 
300 AAAA ::1")}, + }, + // Multiple A Record Test + { + Qname: "server1.prod.region1.skydns.test.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.A("server1.prod.region1.skydns.test. 300 A 10.0.0.1"), + test.A("server1.prod.region1.skydns.test. 300 A 10.0.0.2"), + }, + }, + // Priority Test + { + Qname: "priority.skydns.test.", Qtype: dns.TypeSRV, + Answer: []dns.RR{test.SRV("priority.skydns.test. 300 SRV 333 100 8080 priority.server1.")}, + }, + // Subdomain Test + { + Qname: "sub.region1.skydns.test.", Qtype: dns.TypeSRV, + Answer: []dns.RR{ + test.SRV("sub.region1.skydns.test. 300 IN SRV 10 33 0 sub.server1."), + test.SRV("sub.region1.skydns.test. 300 IN SRV 10 33 80 sub.server2."), + test.SRV("sub.region1.skydns.test. 300 IN SRV 10 33 8080 c.sub.region1.skydns.test."), + }, + Extra: []dns.RR{test.A("c.sub.region1.skydns.test. 300 IN A 10.0.0.1")}, + }, + // SRV TargetStrip Test + { + Qname: "targetstrip.skydns.test.", Qtype: dns.TypeSRV, + Answer: []dns.RR{ + test.SRV("targetstrip.skydns.test. 300 IN SRV 10 100 8080 targetstrip.skydns.test."), + }, + Extra: []dns.RR{test.A("targetstrip.skydns.test. 300 IN A 10.0.0.1")}, + }, + // CNAME (unresolvable internal name) + { + Qname: "cname.prod.region1.skydns.test.", Qtype: dns.TypeA, + Ns: []dns.RR{test.SOA("skydns.test. 30 SOA ns.dns.skydns.test. hostmaster.skydns.test. 0 0 0 0 0")}, + }, + // TXT Test + { + Qname: "txt.server1.prod.region1.skydns.test.", Qtype: dns.TypeTXT, + Answer: []dns.RR{ + test.TXT("txt.server1.prod.region1.skydns.test. 303 IN TXT moretext"), + test.TXT("txt.server1.prod.region1.skydns.test. 303 IN TXT sometext"), + }, + }, + // Wildcard Test + { + Qname: "*.region1.skydns.test.", Qtype: dns.TypeSRV, + Answer: []dns.RR{ + test.SRV("*.region1.skydns.test. 300 IN SRV 10 12 0 sub.server1."), + test.SRV("*.region1.skydns.test. 300 IN SRV 10 12 0 unresolvable.skydns.test."), + test.SRV("*.region1.skydns.test. 300 IN SRV 10 12 80 sub.server2."), + test.SRV("*.region1.skydns.test. 300 IN SRV 10 12 8080 a.server1.prod.region1.skydns.test."), + test.SRV("*.region1.skydns.test. 300 IN SRV 10 12 8080 b.server1.prod.region1.skydns.test."), + test.SRV("*.region1.skydns.test. 300 IN SRV 10 12 8080 b.server6.prod.region1.skydns.test."), + test.SRV("*.region1.skydns.test. 300 IN SRV 10 12 8080 c.sub.region1.skydns.test."), + test.SRV("*.region1.skydns.test. 300 IN SRV 10 12 8080 dev.server1."), + }, + Extra: []dns.RR{ + test.A("a.server1.prod.region1.skydns.test. 300 IN A 10.0.0.1"), + test.A("b.server1.prod.region1.skydns.test. 300 IN A 10.0.0.2"), + test.AAAA("b.server6.prod.region1.skydns.test. 300 IN AAAA ::1"), + test.A("c.sub.region1.skydns.test. 300 IN A 10.0.0.1"), + }, + }, + // Wildcard Test + { + Qname: "prod.*.skydns.test.", Qtype: dns.TypeSRV, + Answer: []dns.RR{ + + test.SRV("prod.*.skydns.test. 300 IN SRV 10 25 0 unresolvable.skydns.test."), + test.SRV("prod.*.skydns.test. 300 IN SRV 10 25 8080 a.server1.prod.region1.skydns.test."), + test.SRV("prod.*.skydns.test. 300 IN SRV 10 25 8080 b.server1.prod.region1.skydns.test."), + test.SRV("prod.*.skydns.test. 300 IN SRV 10 25 8080 b.server6.prod.region1.skydns.test."), + }, + Extra: []dns.RR{ + test.A("a.server1.prod.region1.skydns.test. 300 IN A 10.0.0.1"), + test.A("b.server1.prod.region1.skydns.test. 300 IN A 10.0.0.2"), + test.AAAA("b.server6.prod.region1.skydns.test. 300 IN AAAA ::1"), + }, + }, + // Wildcard Test + { + Qname: "prod.any.skydns.test.", Qtype: dns.TypeSRV, + Answer: []dns.RR{ + test.SRV("prod.any.skydns.test. 
300 IN SRV 10 25 0 unresolvable.skydns.test."), + test.SRV("prod.any.skydns.test. 300 IN SRV 10 25 8080 a.server1.prod.region1.skydns.test."), + test.SRV("prod.any.skydns.test. 300 IN SRV 10 25 8080 b.server1.prod.region1.skydns.test."), + test.SRV("prod.any.skydns.test. 300 IN SRV 10 25 8080 b.server6.prod.region1.skydns.test."), + }, + Extra: []dns.RR{ + test.A("a.server1.prod.region1.skydns.test. 300 IN A 10.0.0.1"), + test.A("b.server1.prod.region1.skydns.test. 300 IN A 10.0.0.2"), + test.AAAA("b.server6.prod.region1.skydns.test. 300 IN AAAA ::1"), + }, + }, + // CNAME loop detection + { + Qname: "a.cname.skydns.test.", Qtype: dns.TypeA, + Ns: []dns.RR{test.SOA("skydns.test. 30 SOA ns.dns.skydns.test. hostmaster.skydns.test. 1407441600 28800 7200 604800 60")}, + }, + // NODATA Test + { + Qname: "a.server1.dev.region1.skydns.test.", Qtype: dns.TypeTXT, + Ns: []dns.RR{test.SOA("skydns.test. 30 SOA ns.dns.skydns.test. hostmaster.skydns.test. 0 0 0 0 0")}, + }, + // NODATA Test + { + Qname: "a.server1.dev.region1.skydns.test.", Qtype: dns.TypeHINFO, + Ns: []dns.RR{test.SOA("skydns.test. 30 SOA ns.dns.skydns.test. hostmaster.skydns.test. 0 0 0 0 0")}, + }, + // NXDOMAIN Test + { + Qname: "a.server1.nonexistent.region1.skydns.test.", Qtype: dns.TypeHINFO, Rcode: dns.RcodeNameError, + Ns: []dns.RR{test.SOA("skydns.test. 30 SOA ns.dns.skydns.test. hostmaster.skydns.test. 0 0 0 0 0")}, + }, + { + Qname: "skydns.test.", Qtype: dns.TypeSOA, + Answer: []dns.RR{test.SOA("skydns.test. 30 IN SOA ns.dns.skydns.test. hostmaster.skydns.test. 1460498836 14400 3600 604800 60")}, + }, + // NS Record Test + { + Qname: "skydns.test.", Qtype: dns.TypeNS, + Answer: []dns.RR{ + test.NS("skydns.test. 300 NS a.ns.dns.skydns.test."), + test.NS("skydns.test. 300 NS b.ns.dns.skydns.test."), + test.NS("skydns.test. 300 NS c.ns.dns.skydns.test."), + }, + Extra: []dns.RR{ + test.A("a.ns.dns.skydns.test. 300 A 10.0.0.2"), + test.A("b.ns.dns.skydns.test. 300 A 10.0.0.3"), + test.A("c.ns.dns.skydns.test. 300 A 10.0.0.4"), + test.A("c.ns.dns.skydns.test. 300 A 10.0.0.5"), + }, + }, + // NS Record Test + { + Qname: "a.skydns.test.", Qtype: dns.TypeNS, Rcode: dns.RcodeNameError, + Ns: []dns.RR{test.SOA("skydns.test. 30 IN SOA ns.dns.skydns.test. hostmaster.skydns.test. 1460498836 14400 3600 604800 60")}, + }, + // A Record For NS Record Test + { + Qname: "ns.dns.skydns.test.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.A("ns.dns.skydns.test. 300 A 10.0.0.2"), + test.A("ns.dns.skydns.test. 300 A 10.0.0.3"), + test.A("ns.dns.skydns.test. 300 A 10.0.0.4"), + test.A("ns.dns.skydns.test. 300 A 10.0.0.5"), + }, + }, + { + Qname: "skydns_extra.test.", Qtype: dns.TypeSOA, + Answer: []dns.RR{test.SOA("skydns_extra.test. 30 IN SOA ns.dns.skydns_extra.test. hostmaster.skydns_extra.test. 1460498836 14400 3600 604800 60")}, + }, + // A Record Test for backward compatibility for zone records + { + Qname: "skydns_zonea.test.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.A("skydns_zonea.test. 300 A 10.0.0.2"), + test.A("skydns_zonea.test. 300 A 10.0.0.3"), + }, + }, + // A Record Test for single A zone record + { + Qname: "skydns_zoneb.test.", Qtype: dns.TypeA, + Answer: []dns.RR{test.A("skydns_zoneb.test. 300 A 10.0.0.4")}, + }, + // A Record Test for multiple A zone records + { + Qname: "skydns_zonec.test.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.A("skydns_zonec.test. 300 A 10.0.0.4"), + test.A("skydns_zonec.test. 
300 A 10.0.0.5"), + }, + }, + // A Record Test for multiple mixed A and AAAA records + { + Qname: "skydns_zoned.test.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.A("skydns_zoned.test. 300 A 10.0.0.4"), + test.A("skydns_zoned.test. 300 A 10.0.0.5"), + }, + }, + // AAAA Record Test for multiple mixed A and AAAA records + { + Qname: "skydns_zoned.test.", Qtype: dns.TypeAAAA, + Answer: []dns.RR{ + test.AAAA("skydns_zoned.test. 300 AAAA 2003::8:1"), + test.AAAA("skydns_zoned.test. 300 AAAA 2003::8:2"), + }, + }, + // Reverse lookup + { + Qname: "1.0.0.10.in-addr.arpa.", Qtype: dns.TypePTR, + Answer: []dns.RR{test.PTR("1.0.0.10.in-addr.arpa. 300 PTR reverse.example.com.")}, + }, +} + +func newEtcdPlugin() *Etcd { + ctxt = context.TODO() + + endpoints := []string{"http://localhost:2379"} + tlsc, _ := tls.NewTLSConfigFromArgs() + client, _ := newEtcdClient(endpoints, tlsc, "", "") + + return &Etcd{ + Upstream: upstream.New(), + PathPrefix: "skydns", + Zones: []string{"skydns.test.", "skydns_extra.test.", "skydns_zonea.test.", "skydns_zoneb.test.", "skydns_zonec.test.", "skydns_zoned.test.", "in-addr.arpa."}, + Client: client, + } +} + +func set(t *testing.T, e *Etcd, k string, ttl time.Duration, m *msg.Service) { + b, err := json.Marshal(m) + if err != nil { + t.Fatal(err) + } + path, _ := msg.PathWithWildcard(k, e.PathPrefix) + e.Client.KV.Put(ctxt, path, string(b)) +} + +func delete(t *testing.T, e *Etcd, k string) { + path, _ := msg.PathWithWildcard(k, e.PathPrefix) + e.Client.Delete(ctxt, path) +} + +func TestLookup(t *testing.T) { + etc := newEtcdPlugin() + for _, serv := range services { + set(t, etc, serv.Key, 0, serv) + defer delete(t, etc, serv.Key) + } + + for i, tc := range dnsTestCases { + m := tc.Msg() + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + etc.ServeDNS(ctxt, rec, m) + + resp := rec.Msg + if err := test.SortAndCheck(resp, tc); err != nil { + t.Errorf("Test %d: %v", i, err) + } + } +} + +var ctxt context.Context diff --git a/plugin/etcd/msg/path.go b/plugin/etcd/msg/path.go new file mode 100644 index 0000000..2c6cbff --- /dev/null +++ b/plugin/etcd/msg/path.go @@ -0,0 +1,51 @@ +package msg + +import ( + "path" + "strings" + + "github.com/coredns/coredns/plugin/pkg/dnsutil" + + "github.com/miekg/dns" +) + +// Path converts a domainname to an etcd path. If s looks like service.staging.skydns.local., +// the resulting key will be /skydns/local/skydns/staging/service . +func Path(s, prefix string) string { + l := dns.SplitDomainName(s) + for i, j := 0, len(l)-1; i < j; i, j = i+1, j-1 { + l[i], l[j] = l[j], l[i] + } + return path.Join(append([]string{"/" + prefix + "/"}, l...)...) +} + +// Domain is the opposite of Path. +func Domain(s string) string { + l := strings.Split(s, "/") + if l[len(l)-1] == "" { + l = l[:len(l)-1] + } + // start with 1, to strip /skydns + for i, j := 1, len(l)-1; i < j; i, j = i+1, j-1 { + l[i], l[j] = l[j], l[i] + } + return dnsutil.Join(l[1 : len(l)-1]...) +} + +// PathWithWildcard acts as Path, but if a name contains wildcards (* or any), the name will be +// chopped of before the (first) wildcard, and we do a higher level search and +// later find the matching names. So service.*.skydns.local, will look for all +// services under skydns.local and will later check for names that match +// service.*.skydns.local. If a wildcard is found the returned bool is true. 
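// Illustrative example (not part of this commit): how the helpers in this file
// translate between DNS names and etcd keys. Path reverses the labels under the
// configured prefix, Domain undoes it, and PathWithWildcard (below) chops the key at
// the first "*"/"any" label so the caller can do a wider prefix search and filter the
// results afterwards. The expected values are taken from the function comments and
// tests in this change.
package main

import (
	"fmt"

	"github.com/coredns/coredns/plugin/etcd/msg"
)

func main() {
	// DNS name -> etcd key.
	fmt.Println(msg.Path("service.staging.skydns.local.", "skydns"))
	// /skydns/local/skydns/staging/service

	// etcd key -> DNS name.
	fmt.Println(msg.Domain("/skydns/local/cluster/staging/service"))
	// service.staging.cluster.local.

	// Wildcard query: the key is cut just before the wildcard label.
	p, star := msg.PathWithWildcard("service.*.skydns.local.", "skydns")
	fmt.Println(p, star)
	// /skydns/local/skydns true
}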
+func PathWithWildcard(s, prefix string) (string, bool) { + l := dns.SplitDomainName(s) + for i, j := 0, len(l)-1; i < j; i, j = i+1, j-1 { + l[i], l[j] = l[j], l[i] + } + for i, k := range l { + if k == "*" || k == "any" { + return path.Join(append([]string{"/" + prefix + "/"}, l[:i]...)...), true + } + } + return path.Join(append([]string{"/" + prefix + "/"}, l...)...), false +} diff --git a/plugin/etcd/msg/path_test.go b/plugin/etcd/msg/path_test.go new file mode 100644 index 0000000..a20d783 --- /dev/null +++ b/plugin/etcd/msg/path_test.go @@ -0,0 +1,24 @@ +package msg + +import "testing" + +func TestPath(t *testing.T) { + for _, path := range []string{"mydns", "skydns"} { + result := Path("service.staging.skydns.local.", path) + if result != "/"+path+"/local/skydns/staging/service" { + t.Errorf("Failure to get domain's path with prefix: %s", result) + } + } +} + +func TestDomain(t *testing.T) { + result1 := Domain("/skydns/local/cluster/staging/service/") + if result1 != "service.staging.cluster.local." { + t.Errorf("Failure to get domain from etcd key (with a trailing '/'), expect: 'service.staging.cluster.local.', actually get: '%s'", result1) + } + + result2 := Domain("/skydns/local/cluster/staging/service") + if result2 != "service.staging.cluster.local." { + t.Errorf("Failure to get domain from etcd key (without trailing '/'), expect: 'service.staging.cluster.local.' actually get: '%s'", result2) + } +} diff --git a/plugin/etcd/msg/service.go b/plugin/etcd/msg/service.go new file mode 100644 index 0000000..759a862 --- /dev/null +++ b/plugin/etcd/msg/service.go @@ -0,0 +1,176 @@ +// Package msg defines the Service structure which is used for service discovery. +package msg + +import ( + "net" + "strings" + + "github.com/miekg/dns" +) + +// Service defines a discoverable service in etcd. It is the rdata from a SRV +// record, but with a twist. Host (Target in SRV) must be a domain name, but +// if it looks like an IP address (4/6), we will treat it like an IP address. +type Service struct { + Host string `json:"host,omitempty"` + Port int `json:"port,omitempty"` + Priority int `json:"priority,omitempty"` + Weight int `json:"weight,omitempty"` + Text string `json:"text,omitempty"` + Mail bool `json:"mail,omitempty"` // Be an MX record. Priority becomes Preference. + TTL uint32 `json:"ttl,omitempty"` + + // When a SRV record with a "Host: IP-address" is added, we synthesize + // a srv.Target domain name. Normally we convert the full Key where + // the record lives to a DNS name and use this as the srv.Target. When + // TargetStrip > 0 we strip the left most TargetStrip labels from the + // DNS name. + TargetStrip int `json:"targetstrip,omitempty"` + + // Group is used to group (or *not* to group) different services + // together. Services with an identical Group are returned in the same + // answer. + Group string `json:"group,omitempty"` + + // Etcd key where we found this service and ignored from json un-/marshalling + Key string `json:"-"` +} + +// NewSRV returns a new SRV record based on the Service. +func (s *Service) NewSRV(name string, weight uint16) *dns.SRV { + host := dns.Fqdn(s.Host) + if s.TargetStrip > 0 { + host = targetStrip(host, s.TargetStrip) + } + + return &dns.SRV{Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeSRV, Class: dns.ClassINET, Ttl: s.TTL}, + Priority: uint16(s.Priority), Weight: weight, Port: uint16(s.Port), Target: host} +} + +// NewMX returns a new MX record based on the Service. 
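// Illustrative example (not part of this commit): building an SRV record from a
// Service with the constructor defined above. Priority and TTL are set explicitly
// here; in the plugin itself the defaults (priority 10, ttl 300) are filled in by the
// backend before these constructors are called, and the weight argument is computed
// per answer set by the caller.
package main

import (
	"fmt"

	"github.com/coredns/coredns/plugin/etcd/msg"
)

func main() {
	s := &msg.Service{
		Host:     "dev.server1", // a name, so it becomes the SRV target
		Port:     8080,
		Priority: 10,
		TTL:      300,
		Key:      "/skydns/test/skydns/region1/dev/server1/a",
	}

	srv := s.NewSRV("a.server1.dev.region1.skydns.test.", 100)
	fmt.Println(srv)
	// roughly: a.server1.dev.region1.skydns.test. 300 IN SRV 10 100 8080 dev.server1.
}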
+func (s *Service) NewMX(name string) *dns.MX { + host := dns.Fqdn(s.Host) + if s.TargetStrip > 0 { + host = targetStrip(host, s.TargetStrip) + } + + return &dns.MX{Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeMX, Class: dns.ClassINET, Ttl: s.TTL}, + Preference: uint16(s.Priority), Mx: host} +} + +// NewA returns a new A record based on the Service. +func (s *Service) NewA(name string, ip net.IP) *dns.A { + return &dns.A{Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: s.TTL}, A: ip} +} + +// NewAAAA returns a new AAAA record based on the Service. +func (s *Service) NewAAAA(name string, ip net.IP) *dns.AAAA { + return &dns.AAAA{Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: s.TTL}, AAAA: ip} +} + +// NewCNAME returns a new CNAME record based on the Service. +func (s *Service) NewCNAME(name string, target string) *dns.CNAME { + return &dns.CNAME{Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: s.TTL}, Target: dns.Fqdn(target)} +} + +// NewTXT returns a new TXT record based on the Service. +func (s *Service) NewTXT(name string) *dns.TXT { + return &dns.TXT{Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: s.TTL}, Txt: split255(s.Text)} +} + +// NewPTR returns a new PTR record based on the Service. +func (s *Service) NewPTR(name string, target string) *dns.PTR { + return &dns.PTR{Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypePTR, Class: dns.ClassINET, Ttl: s.TTL}, Ptr: dns.Fqdn(target)} +} + +// NewNS returns a new NS record based on the Service. +func (s *Service) NewNS(name string) *dns.NS { + host := dns.Fqdn(s.Host) + if s.TargetStrip > 0 { + host = targetStrip(host, s.TargetStrip) + } + return &dns.NS{Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: s.TTL}, Ns: host} +} + +// Group checks the services in sx, it looks for a Group attribute on the shortest +// keys. If there are multiple shortest keys *and* the group attribute disagrees (and +// is not empty), we don't consider it a group. +// If a group is found, only services with *that* group (or no group) will be returned. +func Group(sx []Service) []Service { + if len(sx) == 0 { + return sx + } + + // Shortest key with group attribute sets the group for this set. + group := sx[0].Group + slashes := strings.Count(sx[0].Key, "/") + length := make([]int, len(sx)) + for i, s := range sx { + x := strings.Count(s.Key, "/") + length[i] = x + if x < slashes { + if s.Group == "" { + break + } + slashes = x + group = s.Group + } + } + + if group == "" { + return sx + } + + ret := []Service{} // with slice-tricks in sx we can prolly save this allocation (TODO) + + for i, s := range sx { + if s.Group == "" { + ret = append(ret, s) + continue + } + + // Disagreement on the same level + if length[i] == slashes && s.Group != group { + return sx + } + + if s.Group == group { + ret = append(ret, s) + } + } + return ret +} + +// Split255 splits a string into 255 byte chunks. +func split255(s string) []string { + if len(s) < 255 { + return []string{s} + } + sx := []string{} + p, i := 0, 255 + for { + if i <= len(s) { + sx = append(sx, s[p:i]) + } else { + sx = append(sx, s[p:]) + break + } + p, i = p+255, i+255 + } + + return sx +} + +// targetStrip strips "targetstrip" labels from the left side of the fully qualified name. 
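// Illustrative sketch (not part of this commit): what TargetStrip does to a
// synthesized target. With the service key ns1.c.ns.dns.skydns.test. and TargetStrip 1
// (as used in the NS tests in this change), the left-most label is dropped so the
// published target becomes c.ns.dns.skydns.test. The helper below mirrors targetStrip
// (defined next) using dns.NextLabel; stripLabels is a name chosen for this example.
package main

import (
	"fmt"

	"github.com/miekg/dns"
)

func stripLabels(name string, n int) string {
	offset, end := 0, false
	for i := 0; i < n; i++ {
		offset, end = dns.NextLabel(name, offset)
	}
	if end {
		// Asked to strip more labels than the name has; keep the original name.
		return name
	}
	return name[offset:]
}

func main() {
	fmt.Println(stripLabels("ns1.c.ns.dns.skydns.test.", 1))
	// c.ns.dns.skydns.test.
}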
+func targetStrip(name string, targetStrip int) string { + offset, end := 0, false + for i := 0; i < targetStrip; i++ { + offset, end = dns.NextLabel(name, offset) + } + if end { + // We overshot the name, use the original one. + offset = 0 + } + name = name[offset:] + return name +} diff --git a/plugin/etcd/msg/service_test.go b/plugin/etcd/msg/service_test.go new file mode 100644 index 0000000..f334aa5 --- /dev/null +++ b/plugin/etcd/msg/service_test.go @@ -0,0 +1,125 @@ +package msg + +import "testing" + +func TestSplit255(t *testing.T) { + xs := split255("abc") + if len(xs) != 1 && xs[0] != "abc" { + t.Errorf("Failure to split abc") + } + s := "" + for i := 0; i < 255; i++ { + s += "a" + } + xs = split255(s) + if len(xs) != 1 && xs[0] != s { + t.Errorf("Failure to split 255 char long string") + } + s += "b" + xs = split255(s) + if len(xs) != 2 || xs[1] != "b" { + t.Errorf("Failure to split 256 char long string: %d", len(xs)) + } + for i := 0; i < 255; i++ { + s += "a" + } + xs = split255(s) + if len(xs) != 3 || xs[2] != "a" { + t.Errorf("Failure to split 510 char long string: %d", len(xs)) + } +} + +func TestGroup(t *testing.T) { + // Key are in the wrong order, but for this test it does not matter. + sx := Group( + []Service{ + {Host: "127.0.0.1", Group: "g1", Key: "b/sub/dom1/skydns/test"}, + {Host: "127.0.0.2", Group: "g2", Key: "a/dom1/skydns/test"}, + }, + ) + // Expecting to return the shortest key with a Group attribute. + if len(sx) != 1 { + t.Fatalf("Failure to group zeroth set: %v", sx) + } + if sx[0].Key != "a/dom1/skydns/test" { + t.Fatalf("Failure to group zeroth set: %v, wrong Key", sx) + } + + // Groups disagree, so we will not do anything. + sx = Group( + []Service{ + {Host: "server1", Group: "g1", Key: "region1/skydns/test"}, + {Host: "server2", Group: "g2", Key: "region1/skydns/test"}, + }, + ) + if len(sx) != 2 { + t.Fatalf("Failure to group first set: %v", sx) + } + + // Group is g1, include only the top-level one. + sx = Group( + []Service{ + {Host: "server1", Group: "g1", Key: "a/dom/region1/skydns/test"}, + {Host: "server2", Group: "g2", Key: "a/subdom/dom/region1/skydns/test"}, + }, + ) + if len(sx) != 1 { + t.Fatalf("Failure to group second set: %v", sx) + } + + // Groupless services must be included. + sx = Group( + []Service{ + {Host: "server1", Group: "g1", Key: "a/dom/region1/skydns/test"}, + {Host: "server2", Group: "g2", Key: "a/subdom/dom/region1/skydns/test"}, + {Host: "server2", Group: "", Key: "b/subdom/dom/region1/skydns/test"}, + }, + ) + if len(sx) != 2 { + t.Fatalf("Failure to group third set: %v", sx) + } + + // Empty group on the highest level: include that one also. + sx = Group( + []Service{ + {Host: "server1", Group: "g1", Key: "a/dom/region1/skydns/test"}, + {Host: "server1", Group: "", Key: "b/dom/region1/skydns/test"}, + {Host: "server2", Group: "g2", Key: "a/subdom/dom/region1/skydns/test"}, + }, + ) + if len(sx) != 2 { + t.Fatalf("Failure to group fourth set: %v", sx) + } + + // Empty group on the highest level: include that one also, and the rest. + sx = Group( + []Service{ + {Host: "server1", Group: "g5", Key: "a/dom/region1/skydns/test"}, + {Host: "server1", Group: "", Key: "b/dom/region1/skydns/test"}, + {Host: "server2", Group: "g5", Key: "a/subdom/dom/region1/skydns/test"}, + }, + ) + if len(sx) != 3 { + t.Fatalf("Failure to group fifth set: %v", sx) + } + + // One group. 
+ sx = Group( + []Service{ + {Host: "server1", Group: "g6", Key: "a/dom/region1/skydns/test"}, + }, + ) + if len(sx) != 1 { + t.Fatalf("Failure to group sixth set: %v", sx) + } + + // No group, once service + sx = Group( + []Service{ + {Host: "server1", Key: "a/dom/region1/skydns/test"}, + }, + ) + if len(sx) != 1 { + t.Fatalf("Failure to group seventh set: %v", sx) + } +} diff --git a/plugin/etcd/msg/type.go b/plugin/etcd/msg/type.go new file mode 100644 index 0000000..a300eac --- /dev/null +++ b/plugin/etcd/msg/type.go @@ -0,0 +1,35 @@ +package msg + +import ( + "net" + + "github.com/miekg/dns" +) + +// HostType returns the DNS type of what is encoded in the Service Host field. We're reusing +// dns.TypeXXX to not reinvent a new set of identifiers. +// +// dns.TypeA: the service's Host field contains an A record. +// dns.TypeAAAA: the service's Host field contains an AAAA record. +// dns.TypeCNAME: the service's Host field contains a name. +// +// Note that a service can double/triple as a TXT record or MX record. +func (s *Service) HostType() (what uint16, normalized net.IP) { + ip := net.ParseIP(s.Host) + + switch { + case ip == nil: + if len(s.Text) == 0 { + return dns.TypeCNAME, nil + } + return dns.TypeTXT, nil + + case ip.To4() != nil: + return dns.TypeA, ip.To4() + + case ip.To4() == nil: + return dns.TypeAAAA, ip.To16() + } + // This should never be reached. + return dns.TypeNone, nil +} diff --git a/plugin/etcd/msg/type_test.go b/plugin/etcd/msg/type_test.go new file mode 100644 index 0000000..721f5a8 --- /dev/null +++ b/plugin/etcd/msg/type_test.go @@ -0,0 +1,30 @@ +package msg + +import ( + "testing" + + "github.com/miekg/dns" +) + +func TestType(t *testing.T) { + tests := []struct { + serv Service + expectedType uint16 + }{ + {Service{Host: "example.org"}, dns.TypeCNAME}, + {Service{Host: "127.0.0.1"}, dns.TypeA}, + {Service{Host: "2000::3"}, dns.TypeAAAA}, + {Service{Host: "2000..3"}, dns.TypeCNAME}, + {Service{Host: "127.0.0.257"}, dns.TypeCNAME}, + {Service{Host: "127.0.0.252", Mail: true}, dns.TypeA}, + {Service{Host: "127.0.0.252", Mail: true, Text: "a"}, dns.TypeA}, + {Service{Host: "127.0.0.254", Mail: false, Text: "a"}, dns.TypeA}, + } + + for i, tc := range tests { + what, _ := tc.serv.HostType() + if what != tc.expectedType { + t.Errorf("Test %d: Expected what %v, but got %v", i, tc.expectedType, what) + } + } +} diff --git a/plugin/etcd/multi_test.go b/plugin/etcd/multi_test.go new file mode 100644 index 0000000..7993a25 --- /dev/null +++ b/plugin/etcd/multi_test.go @@ -0,0 +1,60 @@ +//go:build etcd + +package etcd + +import ( + "testing" + + "github.com/coredns/coredns/plugin/etcd/msg" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestMultiLookup(t *testing.T) { + etc := newEtcdPlugin() + etc.Zones = []string{"skydns.test.", "miek.nl."} + etc.Next = test.ErrorHandler() + + for _, serv := range servicesMulti { + set(t, etc, serv.Key, 0, serv) + defer delete(t, etc, serv.Key) + } + for _, tc := range dnsTestCasesMulti { + m := tc.Msg() + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err := etc.ServeDNS(ctxt, rec, m) + if err != nil { + t.Errorf("Expected no error, got %v", err) + return + } + + resp := rec.Msg + if err := test.SortAndCheck(resp, tc); err != nil { + t.Error(err) + } + } +} + +// Note the key is encoded as DNS name, while in "reality" it is a etcd path. 
+var servicesMulti = []*msg.Service{ + {Host: "dev.server1", Port: 8080, Key: "a.server1.dev.region1.skydns.test."}, + {Host: "dev.server1", Port: 8080, Key: "a.server1.dev.region1.miek.nl."}, + {Host: "dev.server1", Port: 8080, Key: "a.server1.dev.region1.example.org."}, +} + +var dnsTestCasesMulti = []test.Case{ + { + Qname: "a.server1.dev.region1.skydns.test.", Qtype: dns.TypeSRV, + Answer: []dns.RR{test.SRV("a.server1.dev.region1.skydns.test. 300 SRV 10 100 8080 dev.server1.")}, + }, + { + Qname: "a.server1.dev.region1.miek.nl.", Qtype: dns.TypeSRV, + Answer: []dns.RR{test.SRV("a.server1.dev.region1.miek.nl. 300 SRV 10 100 8080 dev.server1.")}, + }, + { + Qname: "a.server1.dev.region1.example.org.", Qtype: dns.TypeSRV, Rcode: dns.RcodeServerFailure, + }, +} diff --git a/plugin/etcd/other_test.go b/plugin/etcd/other_test.go new file mode 100644 index 0000000..a71260f --- /dev/null +++ b/plugin/etcd/other_test.go @@ -0,0 +1,138 @@ +//go:build etcd + +// tests mx and txt records + +package etcd + +import ( + "fmt" + "strings" + "testing" + + "github.com/coredns/coredns/plugin/etcd/msg" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestOtherLookup(t *testing.T) { + etc := newEtcdPlugin() + + for _, serv := range servicesOther { + set(t, etc, serv.Key, 0, serv) + defer delete(t, etc, serv.Key) + } + for _, tc := range dnsTestCasesOther { + m := tc.Msg() + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err := etc.ServeDNS(ctxt, rec, m) + if err != nil { + t.Errorf("Expected no error, got %v", err) + continue + } + + resp := rec.Msg + if err := test.SortAndCheck(resp, tc); err != nil { + t.Error(err) + } + } +} + +// Note the key is encoded as DNS name, while in "reality" it is a etcd path. +var servicesOther = []*msg.Service{ + {Host: "dev.server1", Port: 8080, Key: "a.server1.dev.region1.skydns.test."}, + + // mx + {Host: "mx.skydns.test", Priority: 50, Mail: true, Key: "a.mail.skydns.test."}, + {Host: "mx.miek.nl", Priority: 50, Mail: true, Key: "b.mail.skydns.test."}, + {Host: "a.ipaddr.skydns.test", Priority: 30, Mail: true, Key: "a.mx.skydns.test."}, + + {Host: "a.ipaddr.skydns.test", Mail: true, Key: "a.mx2.skydns.test."}, + {Host: "b.ipaddr.skydns.test", Mail: true, Key: "b.mx2.skydns.test."}, + + {Host: "a.ipaddr.skydns.test", Priority: 20, Mail: true, Key: "a.mx3.skydns.test."}, + {Host: "a.ipaddr.skydns.test", Priority: 30, Mail: true, Key: "b.mx3.skydns.test."}, + + {Host: "172.16.1.1", Key: "a.ipaddr.skydns.test."}, + {Host: "172.16.1.2", Key: "b.ipaddr.skydns.test."}, + + // txt + {Text: "abc", Key: "a1.txt.skydns.test."}, + {Text: "abc abc", Key: "a2.txt.skydns.test."}, + // txt sizes + {Text: strings.Repeat("0", 400), Key: "large400.skydns.test."}, + {Text: strings.Repeat("0", 600), Key: "large600.skydns.test."}, + {Text: strings.Repeat("0", 2000), Key: "large2000.skydns.test."}, + + // duplicate ip address + {Host: "10.11.11.10", Key: "http.multiport.http.skydns.test.", Port: 80}, + {Host: "10.11.11.10", Key: "https.multiport.http.skydns.test.", Port: 443}, +} + +var dnsTestCasesOther = []test.Case{ + // MX Tests + { + // NODATA as this is not an Mail: true record. + Qname: "a.server1.dev.region1.skydns.test.", Qtype: dns.TypeMX, + Ns: []dns.RR{ + test.SOA("skydns.test. 30 SOA ns.dns.skydns.test. hostmaster.skydns.test. 0 0 0 0 0"), + }, + }, + { + Qname: "a.mail.skydns.test.", Qtype: dns.TypeMX, + Answer: []dns.RR{test.MX("a.mail.skydns.test. 
300 IN MX 50 mx.skydns.test.")}, + Extra: []dns.RR{ + test.A("a.ipaddr.skydns.test. 300 IN A 172.16.1.1"), + test.CNAME("mx.skydns.test. 300 IN CNAME a.ipaddr.skydns.test."), + }, + }, + { + Qname: "mx2.skydns.test.", Qtype: dns.TypeMX, + Answer: []dns.RR{ + test.MX("mx2.skydns.test. 300 IN MX 10 a.ipaddr.skydns.test."), + test.MX("mx2.skydns.test. 300 IN MX 10 b.ipaddr.skydns.test."), + }, + Extra: []dns.RR{ + test.A("a.ipaddr.skydns.test. 300 A 172.16.1.1"), + test.A("b.ipaddr.skydns.test. 300 A 172.16.1.2"), + }, + }, + // different priority, same host + { + Qname: "mx3.skydns.test.", Qtype: dns.TypeMX, + Answer: []dns.RR{ + test.MX("mx3.skydns.test. 300 IN MX 20 a.ipaddr.skydns.test."), + test.MX("mx3.skydns.test. 300 IN MX 30 a.ipaddr.skydns.test."), + }, + Extra: []dns.RR{ + test.A("a.ipaddr.skydns.test. 300 A 172.16.1.1"), + }, + }, + // Txt + { + Qname: "a1.txt.skydns.test.", Qtype: dns.TypeTXT, + Answer: []dns.RR{ + test.TXT("a1.txt.skydns.test. 300 IN TXT \"abc\""), + }, + }, + { + Qname: "a2.txt.skydns.test.", Qtype: dns.TypeTXT, + Answer: []dns.RR{ + test.TXT("a2.txt.skydns.test. 300 IN TXT \"abc abc\""), + }, + }, + // Large txt less than 512 + { + Qname: "large400.skydns.test.", Qtype: dns.TypeTXT, + Answer: []dns.RR{ + test.TXT(fmt.Sprintf("large400.skydns.test. 300 IN TXT \"%s\"", strings.Repeat("0", 400))), + }, + }, + // Duplicate IP address test + { + Qname: "multiport.http.skydns.test.", Qtype: dns.TypeA, + Answer: []dns.RR{test.A("multiport.http.skydns.test. 300 IN A 10.11.11.10")}, + }, +} diff --git a/plugin/etcd/setup.go b/plugin/etcd/setup.go new file mode 100644 index 0000000..ab6c4b7 --- /dev/null +++ b/plugin/etcd/setup.go @@ -0,0 +1,124 @@ +package etcd + +import ( + "crypto/tls" + "path/filepath" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + mwtls "github.com/coredns/coredns/plugin/pkg/tls" + "github.com/coredns/coredns/plugin/pkg/upstream" + + etcdcv3 "go.etcd.io/etcd/client/v3" +) + +func init() { plugin.Register("etcd", setup) } + +func setup(c *caddy.Controller) error { + e, err := etcdParse(c) + if err != nil { + return plugin.Error("etcd", err) + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + e.Next = next + return e + }) + + return nil +} + +func etcdParse(c *caddy.Controller) (*Etcd, error) { + config := dnsserver.GetConfig(c) + etc := Etcd{PathPrefix: "skydns"} + var ( + tlsConfig *tls.Config + err error + endpoints = []string{defaultEndpoint} + username string + password string + ) + + etc.Upstream = upstream.New() + + if c.Next() { + etc.Zones = plugin.OriginsFromArgsOrServerBlock(c.RemainingArgs(), c.ServerBlockKeys) + for c.NextBlock() { + switch c.Val() { + case "stubzones": + // ignored, remove later. + case "fallthrough": + etc.Fall.SetZonesFromArgs(c.RemainingArgs()) + case "debug": + /* it is a noop now */ + case "path": + if !c.NextArg() { + return &Etcd{}, c.ArgErr() + } + etc.PathPrefix = c.Val() + case "endpoint": + args := c.RemainingArgs() + if len(args) == 0 { + return &Etcd{}, c.ArgErr() + } + endpoints = args + case "upstream": + // remove soon + c.RemainingArgs() + case "tls": // cert key cacertfile + args := c.RemainingArgs() + for i := range args { + if !filepath.IsAbs(args[i]) && config.Root != "" { + args[i] = filepath.Join(config.Root, args[i]) + } + } + tlsConfig, err = mwtls.NewTLSConfigFromArgs(args...) 
+ if err != nil { + return &Etcd{}, err + } + case "credentials": + args := c.RemainingArgs() + if len(args) == 0 { + return &Etcd{}, c.ArgErr() + } + if len(args) != 2 { + return &Etcd{}, c.Errf("credentials requires 2 arguments, username and password") + } + username, password = args[0], args[1] + default: + if c.Val() != "}" { + return &Etcd{}, c.Errf("unknown property '%s'", c.Val()) + } + } + } + client, err := newEtcdClient(endpoints, tlsConfig, username, password) + if err != nil { + return &Etcd{}, err + } + etc.Client = client + etc.endpoints = endpoints + + return &etc, nil + } + return &Etcd{}, nil +} + +func newEtcdClient(endpoints []string, cc *tls.Config, username, password string) (*etcdcv3.Client, error) { + etcdCfg := etcdcv3.Config{ + Endpoints: endpoints, + TLS: cc, + DialKeepAliveTime: etcdTimeout, + } + if username != "" && password != "" { + etcdCfg.Username = username + etcdCfg.Password = password + } + cli, err := etcdcv3.New(etcdCfg) + if err != nil { + return nil, err + } + return cli, nil +} + +const defaultEndpoint = "http://localhost:2379" diff --git a/plugin/etcd/setup_test.go b/plugin/etcd/setup_test.go new file mode 100644 index 0000000..4922641 --- /dev/null +++ b/plugin/etcd/setup_test.go @@ -0,0 +1,118 @@ +//go:build etcd + +package etcd + +import ( + "strings" + "testing" + + "github.com/coredns/caddy" +) + +func TestSetupEtcd(t *testing.T) { + tests := []struct { + input string + shouldErr bool + expectedPath string + expectedEndpoint []string + expectedErrContent string // substring from the expected error. Empty for positive cases. + username string + password string + }{ + // positive + { + `etcd`, false, "skydns", []string{"http://localhost:2379"}, "", "", "", + }, + { + `etcd { + endpoint http://localhost:2379 http://localhost:3379 http://localhost:4379 + +}`, false, "skydns", []string{"http://localhost:2379", "http://localhost:3379", "http://localhost:4379"}, "", "", "", + }, + { + `etcd skydns.local { + endpoint localhost:300 +} +`, false, "skydns", []string{"localhost:300"}, "", "", "", + }, + // negative + { + `etcd { + endpoints localhost:300 +} +`, true, "", []string{""}, "unknown property 'endpoints'", "", "", + }, + // with valid credentials + { + `etcd { + endpoint http://localhost:2379 + credentials username password + } + `, false, "skydns", []string{"http://localhost:2379"}, "", "username", "password", + }, + // with credentials, missing password + { + `etcd { + endpoint http://localhost:2379 + credentials username + } + `, true, "skydns", []string{"http://localhost:2379"}, "credentials requires 2 arguments", "username", "", + }, + // with credentials, missing username and password + { + `etcd { + endpoint http://localhost:2379 + credentials + } + `, true, "skydns", []string{"http://localhost:2379"}, "Wrong argument count", "", "", + }, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + etcd, err := etcdParse(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected error but found %s for input %s", i, err, test.input) + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: Expected no error but found one for input %s. 
Error was: %v", i, test.input, err) + continue + } + + if !strings.Contains(err.Error(), test.expectedErrContent) { + t.Errorf("Test %d: Expected error to contain: %v, found error: %v, input: %s", i, test.expectedErrContent, err.Error(), test.input) + continue + } + } + + if !test.shouldErr && etcd.PathPrefix != test.expectedPath { + t.Errorf("Etcd not correctly set for input %s. Expected: %s, actual: %s", test.input, test.expectedPath, etcd.PathPrefix) + } + if !test.shouldErr { + if len(etcd.endpoints) != len(test.expectedEndpoint) { + t.Errorf("Etcd not correctly set for input %s. Expected: '%+v', actual: '%+v'", test.input, test.expectedEndpoint, etcd.endpoints) + } + for i, endpoint := range etcd.endpoints { + if endpoint != test.expectedEndpoint[i] { + t.Errorf("Etcd not correctly set for input %s. Expected: '%+v', actual: '%+v'", test.input, test.expectedEndpoint, etcd.endpoints) + } + } + } + + if !test.shouldErr { + if test.username != "" { + if etcd.Client.Username != test.username { + t.Errorf("Etcd username not correctly set for input %s. Expected: '%+v', actual: '%+v'", test.input, test.username, etcd.Client.Username) + } + } + if test.password != "" { + if etcd.Client.Password != test.password { + t.Errorf("Etcd password not correctly set for input %s. Expected: '%+v', actual: '%+v'", test.input, test.password, etcd.Client.Password) + } + } + } + } +} diff --git a/plugin/etcd/xfr.go b/plugin/etcd/xfr.go new file mode 100644 index 0000000..87a4d78 --- /dev/null +++ b/plugin/etcd/xfr.go @@ -0,0 +1,17 @@ +package etcd + +import ( + "time" + + "github.com/coredns/coredns/request" +) + +// Serial returns the serial number to use. +func (e *Etcd) Serial(state request.Request) uint32 { + return uint32(time.Now().Unix()) +} + +// MinTTL returns the minimal TTL. +func (e *Etcd) MinTTL(state request.Request) uint32 { + return 30 +} diff --git a/plugin/file/README.md b/plugin/file/README.md new file mode 100644 index 0000000..d1bd425 --- /dev/null +++ b/plugin/file/README.md @@ -0,0 +1,112 @@ +# file + +## Name + +*file* - enables serving zone data from an RFC 1035-style master file. + +## Description + +The *file* plugin is used for an "old-style" DNS server. It serves from a preloaded file that exists +on disk contained RFC 1035 styled data. If the zone file contains signatures (i.e., is signed using +DNSSEC), correct DNSSEC answers are returned. Only NSEC is supported! If you use this setup *you* +are responsible for re-signing the zonefile. + +## Syntax + +~~~ +file DBFILE [ZONES...] +~~~ + +* **DBFILE** the database file to read and parse. If the path is relative, the path from the *root* + plugin will be prepended to it. +* **ZONES** zones it should be authoritative for. If empty, the zones from the configuration block + are used. + +If you want to round-robin A and AAAA responses look at the *loadbalance* plugin. + +~~~ +file DBFILE [ZONES... ] { + reload DURATION +} +~~~ + +* `reload` interval to perform a reload of the zone if the SOA version changes. Default is one minute. + Value of `0` means to not scan for changes and reload. For example, `30s` checks the zonefile every 30 seconds + and reloads the zone when serial changes. + +If you need outgoing zone transfers, take a look at the *transfer* plugin. 
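+
+For example, a minimal Corefile sketch (reusing the `db.example.org` zone file from the examples
+below) that checks the zone's SOA serial every 30 seconds and reloads the zone when it changes:
+
+~~~ corefile
+example.org {
+    file db.example.org {
+        reload 30s
+    }
+}
+~~~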
+ +## Examples + +Load the `example.org` zone from `db.example.org` and allow transfers to the internet, but send +notifies to 10.240.1.1 + +~~~ corefile +example.org { + file db.example.org + transfer { + to * 10.240.1.1 + } +} +~~~ + +Where `db.example.org` would contain RRSets (<https://tools.ietf.org/html/rfc7719#section-4>) in the +(text) presentation format from RFC 1035: + +~~~ +$ORIGIN example.org. +@ 3600 IN SOA sns.dns.icann.org. noc.dns.icann.org. 2017042745 7200 3600 1209600 3600 + 3600 IN NS a.iana-servers.net. + 3600 IN NS b.iana-servers.net. + +www IN A 127.0.0.1 + IN AAAA ::1 +~~~ + + +Or use a single zone file for multiple zones: + +~~~ corefile +. { + file example.org.signed example.org example.net + transfer example.org example.net { + to * 10.240.1.1 + } +} +~~~ + +Note that if you have a configuration like the following you may run into a problem of the origin +not being correctly recognized: + +~~~ corefile +. { + file db.example.org +} +~~~ + +We omit the origin for the file `db.example.org`, so this references the zone in the server block, +which, in this case, is the root zone. Any contents of `db.example.org` will then read with that +origin set; this may or may not do what you want. +It's better to be explicit here and specify the correct origin. This can be done in two ways: + +~~~ corefile +. { + file db.example.org example.org +} +~~~ + +Or + +~~~ corefile +example.org { + file db.example.org +} +~~~ + +## See Also + +See the *loadbalance* plugin if you need simple record shuffling. And the *transfer* plugin for zone +transfers. Lastly the *root* plugin can help you specify the location of the zone files. + +See [RFC 1035](https://www.rfc-editor.org/rfc/rfc1035.txt) for more info on how to structure zone +files. diff --git a/plugin/file/apex_test.go b/plugin/file/apex_test.go new file mode 100644 index 0000000..2108543 --- /dev/null +++ b/plugin/file/apex_test.go @@ -0,0 +1,45 @@ +package file + +import ( + "context" + "strings" + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +const exampleApexOnly = `$ORIGIN example.com. +@ IN SOA ns1.example.com. admin.example.com. ( + 2005011437 ; Serial + 1200 ; Refresh + 144 ; Retry + 1814400 ; Expire + 2h ) ; Minimum +@ IN NS ns1.example.com. +` + +func TestLookupApex(t *testing.T) { + // this tests a zone with *only* an apex. The behavior here is wrong, we should return NODATA, but we do a NXDOMAIN. + // Adding this test to document this. Note a zone that doesn't have any data is pretty useless anyway, so rather than + // fix this with an entirely new branch in lookup.go, just live with it. + zone, err := Parse(strings.NewReader(exampleApexOnly), "example.com.", "stdin", 0) + if err != nil { + t.Fatalf("Expected no error when reading zone, got %q", err) + } + fm := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{"example.com.": zone}, Names: []string{"example.com."}}} + ctx := context.TODO() + + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + if _, err := fm.ServeDNS(ctx, rec, m); err != nil { + t.Errorf("Expected no error, got %v", err) + } + if rec.Msg.Rcode != dns.RcodeNameError { // Should be RcodeSuccess in a perfect world. 
+ t.Errorf("Expected rcode %d, got %d", dns.RcodeNameError, rec.Msg.Rcode) + } +} diff --git a/plugin/file/closest.go b/plugin/file/closest.go new file mode 100644 index 0000000..5059194 --- /dev/null +++ b/plugin/file/closest.go @@ -0,0 +1,23 @@ +package file + +import ( + "github.com/coredns/coredns/plugin/file/tree" + + "github.com/miekg/dns" +) + +// ClosestEncloser returns the closest encloser for qname. +func (z *Zone) ClosestEncloser(qname string) (*tree.Elem, bool) { + offset, end := dns.NextLabel(qname, 0) + for !end { + elem, _ := z.Tree.Search(qname) + if elem != nil { + return elem, true + } + qname = qname[offset:] + + offset, end = dns.NextLabel(qname, offset) + } + + return z.Tree.Search(z.origin) +} diff --git a/plugin/file/closest_test.go b/plugin/file/closest_test.go new file mode 100644 index 0000000..40c04ff --- /dev/null +++ b/plugin/file/closest_test.go @@ -0,0 +1,38 @@ +package file + +import ( + "strings" + "testing" +) + +func TestClosestEncloser(t *testing.T) { + z, err := Parse(strings.NewReader(dbMiekNL), testzone, "stdin", 0) + if err != nil { + t.Fatalf("Expect no error when reading zone, got %q", err) + } + + tests := []struct { + in, out string + }{ + {"miek.nl.", "miek.nl."}, + {"www.miek.nl.", "www.miek.nl."}, + + {"blaat.miek.nl.", "miek.nl."}, + {"blaat.www.miek.nl.", "www.miek.nl."}, + {"www.blaat.miek.nl.", "miek.nl."}, + {"blaat.a.miek.nl.", "a.miek.nl."}, + } + + for _, tc := range tests { + ce, _ := z.ClosestEncloser(tc.in) + if ce == nil { + if z.origin != tc.out { + t.Errorf("Expected ce to be %s for %s, got %s", tc.out, tc.in, ce.Name()) + } + continue + } + if ce.Name() != tc.out { + t.Errorf("Expected ce to be %s for %s, got %s", tc.out, tc.in, ce.Name()) + } + } +} diff --git a/plugin/file/delegation_test.go b/plugin/file/delegation_test.go new file mode 100644 index 0000000..a6da621 --- /dev/null +++ b/plugin/file/delegation_test.go @@ -0,0 +1,228 @@ +package file + +import ( + "context" + "strings" + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +var delegationTestCases = []test.Case{ + { + Qname: "a.delegated.miek.nl.", Qtype: dns.TypeTXT, + Ns: []dns.RR{ + test.NS("delegated.miek.nl. 1800 IN NS a.delegated.miek.nl."), + test.NS("delegated.miek.nl. 1800 IN NS ns-ext.nlnetlabs.nl."), + }, + Extra: []dns.RR{ + test.A("a.delegated.miek.nl. 1800 IN A 139.162.196.78"), + test.AAAA("a.delegated.miek.nl. 1800 IN AAAA 2a01:7e00::f03c:91ff:fef1:6735"), + }, + }, + { + Qname: "delegated.miek.nl.", Qtype: dns.TypeNS, + Ns: []dns.RR{ + test.NS("delegated.miek.nl. 1800 IN NS a.delegated.miek.nl."), + test.NS("delegated.miek.nl. 1800 IN NS ns-ext.nlnetlabs.nl."), + }, + Extra: []dns.RR{ + test.A("a.delegated.miek.nl. 1800 IN A 139.162.196.78"), + test.AAAA("a.delegated.miek.nl. 1800 IN AAAA 2a01:7e00::f03c:91ff:fef1:6735"), + }, + }, + { + Qname: "foo.delegated.miek.nl.", Qtype: dns.TypeA, + Ns: []dns.RR{ + test.NS("delegated.miek.nl. 1800 IN NS a.delegated.miek.nl."), + test.NS("delegated.miek.nl. 1800 IN NS ns-ext.nlnetlabs.nl."), + }, + Extra: []dns.RR{ + test.A("a.delegated.miek.nl. 1800 IN A 139.162.196.78"), + test.AAAA("a.delegated.miek.nl. 1800 IN AAAA 2a01:7e00::f03c:91ff:fef1:6735"), + }, + }, + { + Qname: "foo.delegated.miek.nl.", Qtype: dns.TypeTXT, + Ns: []dns.RR{ + test.NS("delegated.miek.nl. 1800 IN NS a.delegated.miek.nl."), + test.NS("delegated.miek.nl. 1800 IN NS ns-ext.nlnetlabs.nl."), + }, + Extra: []dns.RR{ + test.A("a.delegated.miek.nl. 
1800 IN A 139.162.196.78"), + test.AAAA("a.delegated.miek.nl. 1800 IN AAAA 2a01:7e00::f03c:91ff:fef1:6735"), + }, + }, + { + Qname: "foo.delegated.miek.nl.", Qtype: dns.TypeSOA, + Ns: []dns.RR{ + test.NS("delegated.miek.nl. 1800 IN NS a.delegated.miek.nl."), + test.NS("delegated.miek.nl. 1800 IN NS ns-ext.nlnetlabs.nl."), + }, + Extra: []dns.RR{ + test.A("a.delegated.miek.nl. 1800 IN A 139.162.196.78"), + test.AAAA("a.delegated.miek.nl. 1800 IN AAAA 2a01:7e00::f03c:91ff:fef1:6735"), + }, + }, + { + Qname: "miek.nl.", Qtype: dns.TypeSOA, + Answer: []dns.RR{ + test.SOA("miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 1282630057 14400 3600 604800 14400"), + }, + Ns: miekAuth, + }, + { + Qname: "miek.nl.", Qtype: dns.TypeAAAA, + Ns: []dns.RR{ + test.SOA("miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 1282630057 14400 3600 604800 14400"), + }, + }, +} + +var secureDelegationTestCases = []test.Case{ + { + Qname: "a.delegated.example.org.", Qtype: dns.TypeTXT, Do: true, + Ns: []dns.RR{ + test.DS("delegated.example.org. 1800 IN DS 10056 5 1 EE72CABD1927759CDDA92A10DBF431504B9E1F13"), + test.DS("delegated.example.org. 1800 IN DS 10056 5 2 E4B05F87725FA86D9A64F1E53C3D0E6250946599DFE639C45955B0ED416CDDFA"), + test.NS("delegated.example.org. 1800 IN NS a.delegated.example.org."), + test.NS("delegated.example.org. 1800 IN NS ns-ext.nlnetlabs.nl."), + test.RRSIG("delegated.example.org. 1800 IN RRSIG DS 13 3 1800 20161129153240 20161030153240 49035 example.org. rlNNzcUmtbjLSl02ZzQGUbWX75yCUx0Mug1jHtKVqRq1hpPE2S3863tIWSlz+W9wz4o19OI4jbznKKqk+DGKog=="), + }, + Extra: []dns.RR{ + test.A("a.delegated.example.org. 1800 IN A 139.162.196.78"), + test.AAAA("a.delegated.example.org. 1800 IN AAAA 2a01:7e00::f03c:91ff:fef1:6735"), + }, + }, + { + Qname: "delegated.example.org.", Qtype: dns.TypeNS, Do: true, + Ns: []dns.RR{ + test.DS("delegated.example.org. 1800 IN DS 10056 5 1 EE72CABD1927759CDDA92A10DBF431504B9E1F13"), + test.DS("delegated.example.org. 1800 IN DS 10056 5 2 E4B05F87725FA86D9A64F1E53C3D0E6250946599DFE639C45955B0ED416CDDFA"), + test.NS("delegated.example.org. 1800 IN NS a.delegated.example.org."), + test.NS("delegated.example.org. 1800 IN NS ns-ext.nlnetlabs.nl."), + test.RRSIG("delegated.example.org. 1800 IN RRSIG DS 13 3 1800 20161129153240 20161030153240 49035 example.org. rlNNzcUmtbjLSl02ZzQGUbWX75yCUx0Mug1jHtKVqRq1hpPE2S3863tIWSlz+W9wz4o19OI4jbznKKqk+DGKog=="), + }, + Extra: []dns.RR{ + test.A("a.delegated.example.org. 1800 IN A 139.162.196.78"), + test.AAAA("a.delegated.example.org. 1800 IN AAAA 2a01:7e00::f03c:91ff:fef1:6735"), + }, + }, + { + Qname: "foo.delegated.example.org.", Qtype: dns.TypeA, Do: true, + Ns: []dns.RR{ + test.DS("delegated.example.org. 1800 IN DS 10056 5 1 EE72CABD1927759CDDA92A10DBF431504B9E1F13"), + test.DS("delegated.example.org. 1800 IN DS 10056 5 2 E4B05F87725FA86D9A64F1E53C3D0E6250946599DFE639C45955B0ED416CDDFA"), + test.NS("delegated.example.org. 1800 IN NS a.delegated.example.org."), + test.NS("delegated.example.org. 1800 IN NS ns-ext.nlnetlabs.nl."), + test.RRSIG("delegated.example.org. 1800 IN RRSIG DS 13 3 1800 20161129153240 20161030153240 49035 example.org. rlNNzcUmtbjLSl02ZzQGUbWX75yCUx0Mug1jHtKVqRq1hpPE2S3863tIWSlz+W9wz4o19OI4jbznKKqk+DGKog=="), + }, + Extra: []dns.RR{ + test.A("a.delegated.example.org. 1800 IN A 139.162.196.78"), + test.AAAA("a.delegated.example.org. 1800 IN AAAA 2a01:7e00::f03c:91ff:fef1:6735"), + }, + }, + { + Qname: "foo.delegated.example.org.", Qtype: dns.TypeDS, Do: true, + Ns: []dns.RR{ + test.DS("delegated.example.org. 
1800 IN DS 10056 5 1 EE72CABD1927759CDDA92A10DBF431504B9E1F13"), + test.DS("delegated.example.org. 1800 IN DS 10056 5 2 E4B05F87725FA86D9A64F1E53C3D0E6250946599DFE639C45955B0ED416CDDFA"), + test.NS("delegated.example.org. 1800 IN NS a.delegated.example.org."), + test.NS("delegated.example.org. 1800 IN NS ns-ext.nlnetlabs.nl."), + test.RRSIG("delegated.example.org. 1800 IN RRSIG DS 13 3 1800 20161129153240 20161030153240 49035 example.org. rlNNzcUmtbjLSl02ZzQGUbWX75yCUx0Mug1jHtKVqRq1hpPE2S3863tIWSlz+W9wz4o19OI4jbznKKqk+DGKog=="), + }, + Extra: []dns.RR{ + test.A("a.delegated.example.org. 1800 IN A 139.162.196.78"), + test.AAAA("a.delegated.example.org. 1800 IN AAAA 2a01:7e00::f03c:91ff:fef1:6735"), + }, + }, + { + Qname: "delegated.example.org.", Qtype: dns.TypeDS, Do: true, + Answer: []dns.RR{ + test.DS("delegated.example.org. 1800 IN DS 10056 5 1 EE72CABD1927759CDDA92A10DBF431504B9E1F13"), + test.DS("delegated.example.org. 1800 IN DS 10056 5 2 E4B05F87725FA86D9A64F1E53C3D0E6250946599DFE639C45955B0ED416CDDFA"), + test.RRSIG("delegated.example.org. 1800 IN RRSIG DS 13 3 1800 20161129153240 20161030153240 49035 example.org. rlNNzcUmtbjLSl02ZzQGUbWX75yCUx0Mug1jHtKVqRq1hpPE2S3863tIWSlz+W9wz4o19OI4jbznKKqk+DGKog=="), + }, + Ns: []dns.RR{ + test.NS("example.org. 1800 IN NS a.iana-servers.net."), + test.NS("example.org. 1800 IN NS b.iana-servers.net."), + test.RRSIG("example.org. 1800 IN RRSIG NS 13 2 1800 20161129153240 20161030153240 49035 example.org. llrHoIuw="), + }, + }, +} + +var miekAuth = []dns.RR{ + test.NS("miek.nl. 1800 IN NS ext.ns.whyscream.net."), + test.NS("miek.nl. 1800 IN NS linode.atoom.net."), + test.NS("miek.nl. 1800 IN NS ns-ext.nlnetlabs.nl."), + test.NS("miek.nl. 1800 IN NS omval.tednet.nl."), +} + +func TestLookupDelegation(t *testing.T) { + testDelegation(t, dbMiekNLDelegation, testzone, delegationTestCases) +} + +func TestLookupSecureDelegation(t *testing.T) { + testDelegation(t, exampleOrgSigned, "example.org.", secureDelegationTestCases) +} + +func testDelegation(t *testing.T, z, origin string, testcases []test.Case) { + zone, err := Parse(strings.NewReader(z), origin, "stdin", 0) + if err != nil { + t.Fatalf("Expect no error when reading zone, got %q", err) + } + + fm := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{origin: zone}, Names: []string{origin}}} + ctx := context.TODO() + + for _, tc := range testcases { + m := tc.Msg() + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err := fm.ServeDNS(ctx, rec, m) + if err != nil { + t.Errorf("Expected no error, got %q", err) + return + } + + resp := rec.Msg + if err := test.SortAndCheck(resp, tc); err != nil { + t.Error(err) + } + } +} + +const dbMiekNLDelegation = ` +$TTL 30M +$ORIGIN miek.nl. +@ IN SOA linode.atoom.net. miek.miek.nl. ( + 1282630057 ; Serial + 4H ; Refresh + 1H ; Retry + 7D ; Expire + 4H ) ; Negative Cache TTL + IN NS linode.atoom.net. + IN NS ns-ext.nlnetlabs.nl. + IN NS omval.tednet.nl. + IN NS ext.ns.whyscream.net. + + IN MX 1 aspmx.l.google.com. + IN MX 5 alt1.aspmx.l.google.com. + IN MX 5 alt2.aspmx.l.google.com. + IN MX 10 aspmx2.googlemail.com. + IN MX 10 aspmx3.googlemail.com. + +delegated IN NS a.delegated + IN NS ns-ext.nlnetlabs.nl. 
+ +a.delegated IN TXT "obscured" + IN A 139.162.196.78 + IN AAAA 2a01:7e00::f03c:91ff:fef1:6735 + +a IN A 139.162.196.78 + IN AAAA 2a01:7e00::f03c:91ff:fef1:6735 +www IN CNAME a +archive IN CNAME a` diff --git a/plugin/file/delete_test.go b/plugin/file/delete_test.go new file mode 100644 index 0000000..26ee64e --- /dev/null +++ b/plugin/file/delete_test.go @@ -0,0 +1,65 @@ +package file + +import ( + "bytes" + "fmt" + "testing" + + "github.com/coredns/coredns/plugin/file/tree" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +/* +Create a zone with: + + apex + / + a MX + a A + +Test that: we create the proper tree and that delete +deletes the correct elements +*/ + +var tz = NewZone("example.org.", "db.example.org.") + +type treebuf struct { + *bytes.Buffer +} + +func (t *treebuf) printFunc(e *tree.Elem, rrs map[uint16][]dns.RR) error { + fmt.Fprintf(t.Buffer, "%v\n", rrs) // should be fixed order in new go versions. + return nil +} + +func TestZoneInsertAndDelete(t *testing.T) { + tz.Insert(test.SOA("example.org. IN SOA 1 2 3 4 5")) + + if x := tz.Apex.SOA.Header().Name; x != "example.org." { + t.Errorf("Failed to insert SOA, expected %s, git %s", "example.org.", x) + } + + // Insert two RRs and then remove one. + tz.Insert(test.A("a.example.org. IN A 127.0.0.1")) + tz.Insert(test.MX("a.example.org. IN MX 10 mx.example.org.")) + + tz.Delete(test.MX("a.example.org. IN MX 10 mx.example.org.")) + + tb := treebuf{new(bytes.Buffer)} + + tz.Walk(tb.printFunc) + if tb.String() != "map[1:[a.example.org.\t3600\tIN\tA\t127.0.0.1]]\n" { + t.Errorf("Expected 1 A record in tree, got %s", tb.String()) + } + + tz.Delete(test.A("a.example.org. IN A 127.0.0.1")) + + tb.Reset() + + tz.Walk(tb.printFunc) + if tb.String() != "" { + t.Errorf("Expected no record in tree, got %s", tb.String()) + } +} diff --git a/plugin/file/dname.go b/plugin/file/dname.go new file mode 100644 index 0000000..58351a3 --- /dev/null +++ b/plugin/file/dname.go @@ -0,0 +1,44 @@ +package file + +import ( + "github.com/coredns/coredns/plugin/pkg/dnsutil" + + "github.com/miekg/dns" +) + +// substituteDNAME performs the DNAME substitution defined by RFC 6672, +// assuming the QTYPE of the query is not DNAME. It returns an empty +// string if there is no match. +func substituteDNAME(qname, owner, target string) string { + if dns.IsSubDomain(owner, qname) && qname != owner { + labels := dns.SplitDomainName(qname) + labels = append(labels[0:len(labels)-dns.CountLabel(owner)], dns.SplitDomainName(target)...) + + return dnsutil.Join(labels...) + } + + return "" +} + +// synthesizeCNAME returns a CNAME RR pointing to the resulting name of +// the DNAME substitution. The owner name of the CNAME is the QNAME of +// the query and the TTL is the same as the corresponding DNAME RR. +// +// It returns nil if the DNAME substitution has no match. 
+func synthesizeCNAME(qname string, d *dns.DNAME) *dns.CNAME { + target := substituteDNAME(qname, d.Header().Name, d.Target) + if target == "" { + return nil + } + + r := new(dns.CNAME) + r.Hdr = dns.RR_Header{ + Name: qname, + Rrtype: dns.TypeCNAME, + Class: dns.ClassINET, + Ttl: d.Header().Ttl, + } + r.Target = target + + return r +} diff --git a/plugin/file/dname_test.go b/plugin/file/dname_test.go new file mode 100644 index 0000000..cc70bb5 --- /dev/null +++ b/plugin/file/dname_test.go @@ -0,0 +1,300 @@ +package file + +/* +TODO(miek): move to test/ for full server testing + +import ( + "context" + "strings" + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +// RFC 6672, Section 2.2. Assuming QTYPE != DNAME. +var dnameSubstitutionTestCases = []struct { + qname string + owner string + target string + expected string +}{ + {"com.", "example.com.", "example.net.", ""}, + {"example.com.", "example.com.", "example.net.", ""}, + {"a.example.com.", "example.com.", "example.net.", "a.example.net."}, + {"a.b.example.com.", "example.com.", "example.net.", "a.b.example.net."}, + {"ab.example.com.", "b.example.com.", "example.net.", ""}, + {"foo.example.com.", "example.com.", "example.net.", "foo.example.net."}, + {"a.x.example.com.", "x.example.com.", "example.net.", "a.example.net."}, + {"a.example.com.", "example.com.", "y.example.net.", "a.y.example.net."}, + {"cyc.example.com.", "example.com.", "example.com.", "cyc.example.com."}, + {"cyc.example.com.", "example.com.", "c.example.com.", "cyc.c.example.com."}, + {"shortloop.x.x.", "x.", ".", "shortloop.x."}, + {"shortloop.x.", "x.", ".", "shortloop."}, +} + +func TestDNAMESubstitution(t *testing.T) { + for i, tc := range dnameSubstitutionTestCases { + result := substituteDNAME(tc.qname, tc.owner, tc.target) + if result != tc.expected { + if result == "" { + result = "<no match>" + } + + t.Errorf("Case %d: Expected %s -> %s, got %v", i, tc.qname, tc.expected, result) + return + } + } +} + +var dnameTestCases = []test.Case{ + { + Qname: "dname.miek.nl.", Qtype: dns.TypeDNAME, + Answer: []dns.RR{ + test.DNAME("dname.miek.nl. 1800 IN DNAME test.miek.nl."), + }, + Ns: miekAuth, + }, + { + Qname: "dname.miek.nl.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.A("dname.miek.nl. 1800 IN A 127.0.0.1"), + }, + Ns: miekAuth, + }, + { + Qname: "dname.miek.nl.", Qtype: dns.TypeMX, + Answer: []dns.RR{}, + Ns: []dns.RR{ + test.SOA("miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 1282630057 14400 3600 604800 14400"), + }, + }, + { + Qname: "a.dname.miek.nl.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.CNAME("a.dname.miek.nl. 1800 IN CNAME a.test.miek.nl."), + test.A("a.test.miek.nl. 1800 IN A 139.162.196.78"), + test.DNAME("dname.miek.nl. 1800 IN DNAME test.miek.nl."), + }, + Ns: miekAuth, + }, + { + Qname: "www.dname.miek.nl.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.A("a.test.miek.nl. 1800 IN A 139.162.196.78"), + test.DNAME("dname.miek.nl. 1800 IN DNAME test.miek.nl."), + test.CNAME("www.dname.miek.nl. 1800 IN CNAME www.test.miek.nl."), + test.CNAME("www.test.miek.nl. 
1800 IN CNAME a.test.miek.nl."), + }, + Ns: miekAuth, + }, +} + +func TestLookupDNAME(t *testing.T) { + zone, err := Parse(strings.NewReader(dbMiekNLDNAME), testzone, "stdin", 0) + if err != nil { + t.Fatalf("Expect no error when reading zone, got %q", err) + } + + fm := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{testzone: zone}, Names: []string{testzone}}} + ctx := context.TODO() + + for _, tc := range dnameTestCases { + m := tc.Msg() + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err := fm.ServeDNS(ctx, rec, m) + if err != nil { + t.Errorf("Expected no error, got %v", err) + return + } + + resp := rec.Msg + test.SortAndCheck(t, resp, tc) + } +} + +var dnameDnssecTestCases = []test.Case{ + { + // We have no auth section, because the test zone does not have nameservers. + Qname: "ns.example.org.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.A("ns.example.org. 1800 IN A 127.0.0.1"), + }, + }, + { + Qname: "dname.example.org.", Qtype: dns.TypeDNAME, Do: true, + Answer: []dns.RR{ + test.DNAME("dname.example.org. 1800 IN DNAME test.example.org."), + test.RRSIG("dname.example.org. 1800 IN RRSIG DNAME 5 3 1800 20170702091734 20170602091734 54282 example.org. HvXtiBM="), + }, + }, + { + Qname: "a.dname.example.org.", Qtype: dns.TypeA, Do: true, + Answer: []dns.RR{ + test.CNAME("a.dname.example.org. 1800 IN CNAME a.test.example.org."), + test.DNAME("dname.example.org. 1800 IN DNAME test.example.org."), + test.RRSIG("dname.example.org. 1800 IN RRSIG DNAME 5 3 1800 20170702091734 20170602091734 54282 example.org. HvXtiBM="), + }, + }, +} + +func TestLookupDNAMEDNSSEC(t *testing.T) { + zone, err := Parse(strings.NewReader(dbExampleDNAMESigned), testzone, "stdin", 0) + if err != nil { + t.Fatalf("Expect no error when reading zone, got %q", err) + } + + fm := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{"example.org.": zone}, Names: []string{"example.org."}}} + ctx := context.TODO() + + for _, tc := range dnameDnssecTestCases { + m := tc.Msg() + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err := fm.ServeDNS(ctx, rec, m) + if err != nil { + t.Errorf("Expected no error, got %v", err) + return + } + + resp := rec.Msg + test.SortAndCheck(t, resp, tc) + } +} + +const dbMiekNLDNAME = ` +$TTL 30M +$ORIGIN miek.nl. +@ IN SOA linode.atoom.net. miek.miek.nl. ( + 1282630057 ; Serial + 4H ; Refresh + 1H ; Retry + 7D ; Expire + 4H ) ; Negative Cache TTL + IN NS linode.atoom.net. + IN NS ns-ext.nlnetlabs.nl. + IN NS omval.tednet.nl. + IN NS ext.ns.whyscream.net. + +test IN MX 1 aspmx.l.google.com. + IN MX 5 alt1.aspmx.l.google.com. + IN MX 5 alt2.aspmx.l.google.com. + IN MX 10 aspmx2.googlemail.com. + IN MX 10 aspmx3.googlemail.com. +a.test IN A 139.162.196.78 + IN AAAA 2a01:7e00::f03c:91ff:fef1:6735 +www.test IN CNAME a.test + +dname IN DNAME test +dname IN A 127.0.0.1 +a.dname IN A 127.0.0.1 +` + +const dbExampleDNAMESigned = ` +; File written on Fri Jun 2 10:17:34 2017 +; dnssec_signzone version 9.10.3-P4-Debian +example.org. 1800 IN SOA a.example.org. b.example.org. ( + 1282630057 ; serial + 14400 ; refresh (4 hours) + 3600 ; retry (1 hour) + 604800 ; expire (1 week) + 14400 ; minimum (4 hours) + ) + 1800 RRSIG SOA 5 2 1800 ( + 20170702091734 20170602091734 54282 example.org. + mr5eQtFs1GubgwaCcqrpiF6Cgi822OkESPeV + X0OJYq3JzthJjHw8TfYAJWQ2yGqhlePHir9h + FT/uFZdYyytHq+qgIUbJ9IVCrq0gZISZdHML + Ry1DNffMR9CpD77KocOAUABfopcvH/3UGOHn + TFxkAr447zPaaoC68JYGxYLfZk8= ) + 1800 NS ns.example.org. 
+ 1800 RRSIG NS 5 2 1800 ( + 20170702091734 20170602091734 54282 example.org. + McM4UdMxkscVQkJnnEbdqwyjpPgq5a/EuOLA + r2MvG43/cwOaWULiZoNzLi5Rjzhf+GTeVTan + jw6EsL3gEuYI1nznwlLQ04/G0XAHjbq5VvJc + rlscBD+dzf774yfaTjRNoeo2xTem6S7nyYPW + Y+1f6xkrsQPLYJfZ6VZ9QqyupBw= ) + 14400 NSEC dname.example.org. NS SOA RRSIG NSEC DNSKEY + 14400 RRSIG NSEC 5 2 14400 ( + 20170702091734 20170602091734 54282 example.org. + VT+IbjDFajM0doMKFipdX3+UXfCn3iHIxg5x + LElp4Q/YddTbX+6tZf53+EO+G8Kye3JDLwEl + o8VceijNeF3igZ+LiZuXCei5Qg/TJ7IAUnAO + xd85IWwEYwyKkKd6Z2kXbAN2pdcHE8EmboQd + wfTr9oyWhpZk1Z+pN8vdejPrG0M= ) + 1800 DNSKEY 256 3 5 ( + AwEAAczLlmTk5bMXUzpBo/Jta6MWSZYy3Nfw + gz8t/pkfSh4IlFF6vyXZhEqCeQsCBdD7ltkD + h5qd4A+nFrYOMwsi5XIjoHMlJN15xwFS9EgS + ZrZmuxePIEiYB5KccEf9JQMgM1t07Iu1FnrY + 02OuAqGWcO4tuyTLaK3QP4MLQOfAgKqf + ) ; ZSK; alg = RSASHA1; key id = 54282 + 1800 RRSIG DNSKEY 5 2 1800 ( + 20170702091734 20170602091734 54282 example.org. + MBgSRtZ6idJblLIHxZWpWL/1oqIwImb1mkl7 + hDFxqV6Hw19yLX06P7gcJEWiisdZBkVEfcOK + LeMJly05vgKfrMzLgIu2Ry4bL8AMKc8NMXBG + b1VDCEBW69P2omogj2KnORHDCZQr/BX9+wBU + 5rIMTTKlMSI5sT6ecJHHEymtiac= ) +dname.example.org. 1800 IN A 127.0.0.1 + 1800 RRSIG A 5 3 1800 ( + 20170702091734 20170602091734 54282 example.org. + LPCK2nLyDdGwvmzGLkUO2atEUjoc+aEspkC3 + keZCdXZaLnAwBH7dNAjvvXzzy0WrgWeiyDb4 + +rJ2N0oaKEZicM4QQDHKhugJblKbU5G4qTey + LSEaV3vvQnzGd0S6dCqnwfPj9czagFN7Zlf5 + DmLtdxx0aiDPCUpqT0+H/vuGPfk= ) + 1800 DNAME test.example.org. + 1800 RRSIG DNAME 5 3 1800 ( + 20170702091734 20170602091734 54282 example.org. + HvX79T1flWJ8H9/1XZjX6gz8rP/o2jbfPXJ9 + vC7ids/ZJilSReabLru4DCqcw1IV2DM/CZdE + tBnED/T2PJXvMut9tnYMrz+ZFPxoV6XyA3Z7 + bok3B0OuxizzAN2EXdol04VdbMHoWUzjQCzi + 0Ri12zLGRPzDepZ7FolgD+JtiBM= ) + 14400 NSEC a.dname.example.org. A DNAME RRSIG NSEC + 14400 RRSIG NSEC 5 3 14400 ( + 20170702091734 20170602091734 54282 example.org. + U3ZPYMUBJl3wF2SazQv/kBf6ec0CH+7n0Hr9 + w6lBKkiXz7P9WQzJDVnTHEZOrbDI6UetFGyC + 6qcaADCASZ9Wxc+riyK1Hl4ox+Y/CHJ97WHy + oS2X//vEf6qmbHQXin0WQtFdU/VCRYF40X5v + 8VfqOmrr8iKiEqXND8XNVf58mTw= ) +a.dname.example.org. 1800 IN A 127.0.0.1 + 1800 RRSIG A 5 4 1800 ( + 20170702091734 20170602091734 54282 example.org. + y7RHBWZwli8SJQ4BgTmdXmYS3KGHZ7AitJCx + zXFksMQtNoOfVEQBwnFqjAb8ezcV5u92h1gN + i1EcuxCFiElML1XFT8dK2GnlPAga9w3oIwd5 + wzW/YHcnR0P9lF56Sl7RoIt6+jJqOdRfixS6 + TDoLoXsNbOxQ+qV3B8pU2Tam204= ) + 14400 NSEC ns.example.org. A RRSIG NSEC + 14400 RRSIG NSEC 5 4 14400 ( + 20170702091734 20170602091734 54282 example.org. + Tmu27q3+xfONSZZtZLhejBUVtEw+83ZU1AFb + Rsxctjry/x5r2JSxw/sgSAExxX/7tx/okZ8J + oJqtChpsr91Kiw3eEBgINi2lCYIpMJlW4cWz + 8bYlHfR81VsKYgy/cRgrq1RRvBoJnw+nwSty + mKPIvUtt67LAvLxJheSCEMZLCKI= ) +ns.example.org. 1800 IN A 127.0.0.1 + 1800 RRSIG A 5 3 1800 ( + 20170702091734 20170602091734 54282 example.org. + mhi1SGaaAt+ndQEg5uKWKCH0HMzaqh/9dUK3 + p2wWMBrLbTZrcWyz10zRnvehicXDCasbBrer + ZpDQnz5AgxYYBURvdPfUzx1XbNuRJRE4l5PN + CEUTlTWcqCXnlSoPKEJE5HRf7v0xg2BrBUfM + 4mZnW2bFLwjrRQ5mm/mAmHmTROk= ) + 14400 NSEC example.org. A RRSIG NSEC + 14400 RRSIG NSEC 5 3 14400 ( + 20170702091734 20170602091734 54282 example.org. 
+ loHcdjX+NIWLAkUDfPSy2371wrfUvrBQTfMO + 17eO2Y9E/6PE935NF5bjQtZBRRghyxzrFJhm + vY1Ad5ZTb+NLHvdSWbJQJog+eCc7QWp64WzR + RXpMdvaE6ZDwalWldLjC3h8QDywDoFdndoRY + eHOsmTvvtWWqtO6Fa5A8gmHT5HA= ) +` +*/ diff --git a/plugin/file/dnssec_test.go b/plugin/file/dnssec_test.go new file mode 100644 index 0000000..7292523 --- /dev/null +++ b/plugin/file/dnssec_test.go @@ -0,0 +1,350 @@ +package file + +import ( + "context" + "strings" + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +// All OPT RR are added in server.go, so we don't specify them in the unit tests. +var dnssecTestCases = []test.Case{ + { + Qname: "miek.nl.", Qtype: dns.TypeSOA, Do: true, + Answer: []dns.RR{ + test.RRSIG("miek.nl. 1800 IN RRSIG SOA 8 2 1800 20160426031301 20160327031301 12051 miek.nl. FIrzy07acBbtyQczy1dc="), + test.SOA("miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 1282630057 14400 3600 604800 14400"), + }, + Ns: auth, + }, + { + Qname: "miek.nl.", Qtype: dns.TypeAAAA, Do: true, + Answer: []dns.RR{ + test.AAAA("miek.nl. 1800 IN AAAA 2a01:7e00::f03c:91ff:fef1:6735"), + test.RRSIG("miek.nl. 1800 IN RRSIG AAAA 8 2 1800 20160426031301 20160327031301 12051 miek.nl. SsRT="), + }, + Ns: auth, + }, + { + Qname: "miek.nl.", Qtype: dns.TypeNS, Do: true, + Answer: []dns.RR{ + test.NS("miek.nl. 1800 IN NS ext.ns.whyscream.net."), + test.NS("miek.nl. 1800 IN NS linode.atoom.net."), + test.NS("miek.nl. 1800 IN NS ns-ext.nlnetlabs.nl."), + test.NS("miek.nl. 1800 IN NS omval.tednet.nl."), + test.RRSIG("miek.nl. 1800 IN RRSIG NS 8 2 1800 20160426031301 20160327031301 12051 miek.nl. ZLtsQhwaz+lHfNpztFoR1Vxs="), + }, + }, + { + Qname: "miek.nl.", Qtype: dns.TypeMX, Do: true, + Answer: []dns.RR{ + test.MX("miek.nl. 1800 IN MX 1 aspmx.l.google.com."), + test.MX("miek.nl. 1800 IN MX 10 aspmx2.googlemail.com."), + test.MX("miek.nl. 1800 IN MX 10 aspmx3.googlemail.com."), + test.MX("miek.nl. 1800 IN MX 5 alt1.aspmx.l.google.com."), + test.MX("miek.nl. 1800 IN MX 5 alt2.aspmx.l.google.com."), + test.RRSIG("miek.nl. 1800 IN RRSIG MX 8 2 1800 20160426031301 20160327031301 12051 miek.nl. kLqG+iOr="), + }, + Ns: auth, + }, + { + Qname: "www.miek.nl.", Qtype: dns.TypeA, Do: true, + Answer: []dns.RR{ + test.A("a.miek.nl. 1800 IN A 139.162.196.78"), + test.RRSIG("a.miek.nl. 1800 IN RRSIG A 8 3 1800 20160426031301 20160327031301 12051 miek.nl. lxLotCjWZ3kihTxk="), + test.CNAME("www.miek.nl. 1800 IN CNAME a.miek.nl."), + test.RRSIG("www.miek.nl. 1800 RRSIG CNAME 8 3 1800 20160426031301 20160327031301 12051 miek.nl. NVZmMJaypS+wDL2Lar4Zw1zF"), + }, + Ns: auth, + }, + { + // NoData + Qname: "a.miek.nl.", Qtype: dns.TypeSRV, Do: true, + Ns: []dns.RR{ + test.NSEC("a.miek.nl. 14400 IN NSEC archive.miek.nl. A AAAA RRSIG NSEC"), + test.RRSIG("a.miek.nl. 14400 IN RRSIG NSEC 8 3 14400 20160426031301 20160327031301 12051 miek.nl. GqnF6cutipmSHEao="), + test.RRSIG("miek.nl. 1800 IN RRSIG SOA 8 2 1800 20160426031301 20160327031301 12051 miek.nl. FIrzy07acBbtyQczy1dc="), + test.SOA("miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 1282630057 14400 3600 604800 14400"), + }, + }, + { + Qname: "b.miek.nl.", Qtype: dns.TypeA, Do: true, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.NSEC("archive.miek.nl. 14400 IN NSEC go.dns.miek.nl. CNAME RRSIG NSEC"), + test.RRSIG("archive.miek.nl. 14400 IN RRSIG NSEC 8 3 14400 20160426031301 20160327031301 12051 miek.nl. jEpx8lcp4do5fWXg="), + test.NSEC("miek.nl. 14400 IN NSEC a.miek.nl. 
A NS SOA MX AAAA RRSIG NSEC DNSKEY"), + test.RRSIG("miek.nl. 14400 IN RRSIG NSEC 8 2 14400 20160426031301 20160327031301 12051 miek.nl. mFfc3r/9PSC1H6oSpdC"), + test.RRSIG("miek.nl. 1800 IN RRSIG SOA 8 2 1800 20160426031301 20160327031301 12051 miek.nl. FIrzy07acBbtyQczy1dc="), + test.SOA("miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 1282630057 14400 3600 604800 14400"), + }, + }, + { + Qname: "b.blaat.miek.nl.", Qtype: dns.TypeA, Do: true, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.NSEC("archive.miek.nl. 14400 IN NSEC go.dns.miek.nl. CNAME RRSIG NSEC"), + test.RRSIG("archive.miek.nl. 14400 IN RRSIG NSEC 8 3 14400 20160426031301 20160327031301 12051 miek.nl. jEpx8lcp4do5fWXg="), + test.NSEC("miek.nl. 14400 IN NSEC a.miek.nl. A NS SOA MX AAAA RRSIG NSEC DNSKEY"), + test.RRSIG("miek.nl. 14400 IN RRSIG NSEC 8 2 14400 20160426031301 20160327031301 12051 miek.nl. mFfc3r/9PSC1H6oSpdC"), + test.RRSIG("miek.nl. 1800 IN RRSIG SOA 8 2 1800 20160426031301 20160327031301 12051 miek.nl. FIrzy07acBbtyQczy1dc="), + test.SOA("miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 1282630057 14400 3600 604800 14400"), + }, + }, + { + Qname: "b.a.miek.nl.", Qtype: dns.TypeA, Do: true, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + // dedupped NSEC, because 1 nsec tells all + test.NSEC("a.miek.nl. 14400 IN NSEC archive.miek.nl. A AAAA RRSIG NSEC"), + test.RRSIG("a.miek.nl. 14400 IN RRSIG NSEC 8 3 14400 20160426031301 20160327031301 12051 miek.nl. GqnF6cut/RRGPQ1QGQE1ipmSHEao="), + test.RRSIG("miek.nl. 1800 IN RRSIG SOA 8 2 1800 20160426031301 20160327031301 12051 miek.nl. FIrzy07acBbtyQczy1dc="), + test.SOA("miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 1282630057 14400 3600 604800 14400"), + }, + }, +} + +var auth = []dns.RR{ + test.NS("miek.nl. 1800 IN NS ext.ns.whyscream.net."), + test.NS("miek.nl. 1800 IN NS linode.atoom.net."), + test.NS("miek.nl. 1800 IN NS ns-ext.nlnetlabs.nl."), + test.NS("miek.nl. 1800 IN NS omval.tednet.nl."), + test.RRSIG("miek.nl. 1800 IN RRSIG NS 8 2 1800 20160426031301 20160327031301 12051 miek.nl. ZLtsQhwazbqSpztFoR1Vxs="), +} + +func TestLookupDNSSEC(t *testing.T) { + zone, err := Parse(strings.NewReader(dbMiekNLSigned), testzone, "stdin", 0) + if err != nil { + t.Fatalf("Expected no error when reading zone, got %q", err) + } + + fm := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{testzone: zone}, Names: []string{testzone}}} + ctx := context.TODO() + + for _, tc := range dnssecTestCases { + m := tc.Msg() + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err := fm.ServeDNS(ctx, rec, m) + if err != nil { + t.Errorf("Expected no error, got %v", err) + return + } + + resp := rec.Msg + if err := test.SortAndCheck(resp, tc); err != nil { + t.Error(err) + } + } +} + +func BenchmarkFileLookupDNSSEC(b *testing.B) { + zone, err := Parse(strings.NewReader(dbMiekNLSigned), testzone, "stdin", 0) + if err != nil { + return + } + + fm := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{testzone: zone}, Names: []string{testzone}}} + ctx := context.TODO() + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + + tc := test.Case{ + Qname: "b.miek.nl.", Qtype: dns.TypeA, Do: true, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.NSEC("archive.miek.nl. 14400 IN NSEC go.dns.miek.nl. CNAME RRSIG NSEC"), + test.RRSIG("archive.miek.nl. 14400 IN RRSIG NSEC 8 3 14400 20160426031301 20160327031301 12051 miek.nl. jEpx8lcp4do5fWXg="), + test.NSEC("miek.nl. 14400 IN NSEC a.miek.nl. 
A NS SOA MX AAAA RRSIG NSEC DNSKEY"), + test.RRSIG("miek.nl. 14400 IN RRSIG NSEC 8 2 14400 20160426031301 20160327031301 12051 miek.nl. mFfc3r/9PSC1H6oSpdC"), + test.RRSIG("miek.nl. 1800 IN RRSIG SOA 8 2 1800 20160426031301 20160327031301 12051 miek.nl. FIrzy07acBbtyQczy1dc="), + test.SOA("miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 1282630057 14400 3600 604800 14400"), + }, + } + + m := tc.Msg() + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + fm.ServeDNS(ctx, rec, m) + } +} + +const dbMiekNLSigned = ` +; File written on Sun Mar 27 04:13:01 2016 +; dnssec_signzone version 9.10.3-P4-Ubuntu +miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. ( + 1459051981 ; serial + 14400 ; refresh (4 hours) + 3600 ; retry (1 hour) + 604800 ; expire (1 week) + 14400 ; minimum (4 hours) + ) + 1800 RRSIG SOA 8 2 1800 ( + 20160426031301 20160327031301 12051 miek.nl. + FIrzy07acBzrf6kNW13Ypmq/ahojoMqOj0qJ + ixTevTvwOEcVuw9GlJoYIHTYg+hm1sZHtx9K + RiVmYsm8SHKsJA1WzixtT4K7vQvM+T+qbeOJ + xA6YTivKUcGRWRXQlOTUAlHS/KqBEfmxKgRS + 68G4oOEClFDSJKh7RbtyQczy1dc= ) + 1800 NS ext.ns.whyscream.net. + 1800 NS omval.tednet.nl. + 1800 NS linode.atoom.net. + 1800 NS ns-ext.nlnetlabs.nl. + 1800 RRSIG NS 8 2 1800 ( + 20160426031301 20160327031301 12051 miek.nl. + ZLtsQhwaz+CwrgzgFiEAqbqS/JH65MYjziA3 + 6EXwlGDy41lcfGm71PpxA7cDzFhWNkJNk4QF + q48wtpP4IGPPpHbnJHKDUXj6se7S+ylAGbS+ + VgVJ4YaVcE6xA9ZVhVpz8CSSjeH34vmqq9xj + zmFjofuDvraZflHfNpztFoR1Vxs= ) + 1800 A 139.162.196.78 + 1800 RRSIG A 8 2 1800 ( + 20160426031301 20160327031301 12051 miek.nl. + hl+6Q075tsCkxIqbop8zZ6U8rlFvooz7Izzx + MgCZYVLcg75El28EXKIhBfRb1dPaKbd+v+AD + wrJMHL131pY5sU2Ly05K+7CqmmyaXgDaVsKS + rSw/TbhGDIItBemeseeuXGAKAbY2+gE7kNN9 + mZoQ9hRB3SrxE2jhctv66DzYYQQ= ) + 1800 MX 1 aspmx.l.google.com. + 1800 MX 5 alt1.aspmx.l.google.com. + 1800 MX 5 alt2.aspmx.l.google.com. + 1800 MX 10 aspmx2.googlemail.com. + 1800 MX 10 aspmx3.googlemail.com. + 1800 RRSIG MX 8 2 1800 ( + 20160426031301 20160327031301 12051 miek.nl. + kLqG+iOrKSzms1H9Et9me8Zts1rbyeCFSVQD + G9is/u6ec3Lqg2vwJddf/yRsjVpVgadWSAkc + GSDuD2dK8oBeP24axWc3Z1OY2gdMI7w+PKWT + Z+pjHVjbjM47Ii/a6jk5SYeOwpGMsdEwhtTP + vk2O2WGljifqV3uE7GshF5WNR10= ) + 1800 AAAA 2a01:7e00::f03c:91ff:fef1:6735 + 1800 RRSIG AAAA 8 2 1800 ( + 20160426031301 20160327031301 12051 miek.nl. + SsRTHytW4YTAuHovHQgfIMhNwMtMp4gaAU/Z + lgTO+IkBb9y9F8uHrf25gG6RqA1bnGV/gezV + NU5negXm50bf1BNcyn3aCwEbA0rCGYIL+nLJ + szlBVbBu6me/Ym9bbJlfgfHRDfsVy2ZkNL+B + jfNQtGCSDoJwshjcqJlfIVSardo= ) + 14400 NSEC a.miek.nl. A NS SOA MX AAAA RRSIG NSEC DNSKEY + 14400 RRSIG NSEC 8 2 14400 ( + 20160426031301 20160327031301 12051 miek.nl. 
+ mFfc3r/9PSC1H6oSpdC+FDy/Iu02W2Tf0x+b + n6Lpe1gCC1uvcSUrrmBNlyAWRr5Zm+ZXssEb + cKddRGiu/5sf0bUWrs4tqokL/HUl10X/sBxb + HfwNAeD7R7+CkpMv67li5AhsDgmQzpX2r3P6 + /6oZyLvODGobysbmzeWM6ckE8IE= ) + 1800 DNSKEY 256 3 8 ( + AwEAAcNEU67LJI5GEgF9QLNqLO1SMq1EdoQ6 + E9f85ha0k0ewQGCblyW2836GiVsm6k8Kr5EC + IoMJ6fZWf3CQSQ9ycWfTyOHfmI3eQ/1Covhb + 2y4bAmL/07PhrL7ozWBW3wBfM335Ft9xjtXH + Py7ztCbV9qZ4TVDTW/Iyg0PiwgoXVesz + ) ; ZSK; alg = RSASHA256; key id = 12051 + 1800 DNSKEY 257 3 8 ( + AwEAAcWdjBl4W4wh/hPxMDcBytmNCvEngIgB + 9Ut3C2+QI0oVz78/WK9KPoQF7B74JQ/mjO4f + vIncBmPp6mFNxs9/WQX0IXf7oKviEVOXLjct + R4D1KQLX0wprvtUIsQFIGdXaO6suTT5eDbSd + 6tTwu5xIkGkDmQhhH8OQydoEuCwV245ZwF/8 + AIsqBYDNQtQ6zhd6jDC+uZJXg/9LuPOxFHbi + MTjp6j3CCW0kHbfM/YHZErWWtjPj3U3Z7knQ + SIm5PO5FRKBEYDdr5UxWJ/1/20SrzI3iztvP + wHDsA2rdHm/4YRzq7CvG4N0t9ac/T0a0Sxba + /BUX2UVPWaIVBdTRBtgHi0s= + ) ; KSK; alg = RSASHA256; key id = 33694 + 1800 RRSIG DNSKEY 8 2 1800 ( + 20160426031301 20160327031301 12051 miek.nl. + o/D6o8+/bNGQyyRvwZ2hM0BJ+3HirvNjZoko + yGhGe9sPSrYU39WF3JVIQvNJFK6W3/iwlKir + TPOeYlN6QilnztFq1vpCxwj2kxJaIJhZecig + LsKxY/fOHwZlIbBLZZadQG6JoGRLHnImSzpf + xtyVaXQtfnJFC07HHt9np3kICfE= ) + 1800 RRSIG DNSKEY 8 2 1800 ( + 20160426031301 20160327031301 33694 miek.nl. + Ak/mbbQVQV+nUgw5Sw/c+TSoYqIwbLARzuNE + QJvJNoRR4tKVOY6qSxQv+j5S7vzyORZ+yeDp + NlEa1T9kxZVBMABoOtLX5kRqZncgijuH8fxb + L57Sv2IzINI9+DOcy9Q9p9ygtwYzQKrYoNi1 + 0hwHi6emGkVG2gGghruMinwOJASGgQy487Yd + eIpcEKJRw73nxd2le/4/Vafy+mBpKWOczfYi + 5m9MSSxcK56NFYjPG7TvdIw0m70F/smY9KBP + pGWEdzRQDlqfZ4fpDaTAFGyRX0mPFzMbs1DD + 3hQ4LHUSi/NgQakdH9eF42EVEDeL4cI69K98 + 6NNk6X9TRslO694HKw== ) +a.miek.nl. 1800 IN A 139.162.196.78 + 1800 RRSIG A 8 3 1800 ( + 20160426031301 20160327031301 12051 miek.nl. + lxLotCjWZ3kikNNcePu6HOCqMHDINKFRJRD8 + laz2KQ9DKtgXPdnRw5RJvVITSj8GUVzw1ec1 + CYVEKu/eMw/rc953Zns528QBypGPeMNLe2vu + C6a6UhZnGHA48dSd9EX33eSJs0MP9xsC9csv + LGdzYmv++eslkKxkhSOk2j/hTxk= ) + 1800 AAAA 2a01:7e00::f03c:91ff:fef1:6735 + 1800 RRSIG AAAA 8 3 1800 ( + 20160426031301 20160327031301 12051 miek.nl. + ji3QMlaUzlK85ppB5Pc+y2WnfqOi6qrm6dm1 + bXgsEov/5UV1Lmcv8+Y5NBbTbBlXGlWcpqNp + uWpf9z3lbguDWznpnasN2MM8t7yxo/Cr7WRf + QCzui7ewpWiA5hq7j0kVbM4nnDc6cO+U93hO + mMhVbeVI70HM2m0HaHkziEyzVZk= ) + 14400 NSEC archive.miek.nl. A AAAA RRSIG NSEC + 14400 RRSIG NSEC 8 3 14400 ( + 20160426031301 20160327031301 12051 miek.nl. + GqnF6cut/KCxbnJj27MCjjVGkjObV0hLhHOP + E1/GXAUTEKG6BWxJq8hidS3p/yrOmP5PEL9T + 4FjBp0/REdVmGpuLaiHyMselES82p/uMMdY5 + QqRM6LHhZdO1zsRbyzOZbm5MsW6GR7K2kHlX + 9TdBIULiRRGPQ1QGQE1ipmSHEao= ) +archive.miek.nl. 1800 IN CNAME a.miek.nl. + 1800 RRSIG CNAME 8 3 1800 ( + 20160426031301 20160327031301 12051 miek.nl. + s4zVJiDrVuUiUFr8CNQLuXYYfpqpl8rovL50 + BYsub/xK756NENiOTAOjYH6KYg7RSzsygJjV + YQwXolZly2/KXAr48SCtxzkGFxLexxiKcFaj + vm7ZDl7Btoa5l68qmBcxOX5E/W0IKITi4PNK + mhBs7dlaf0IbPGNgMxae72RosxM= ) + 14400 NSEC go.dns.miek.nl. CNAME RRSIG NSEC + 14400 RRSIG NSEC 8 3 14400 ( + 20160426031301 20160327031301 12051 miek.nl. + jEp7LsoK++/PRFh2HieLzasA1jXBpp90NyDf + RfpfOxdM69yRKfvXMc2bazIiMuDhxht79dGI + Gj02cn1cvX60SlaHkeFtqTdJcHdK9rbI65EK + YHFZFzGh9XVnuMJKpUsm/xS1dnUSAnXN8q+0 + xBlUDlQpsAFv/cx8lcp4do5fWXg= ) +go.dns.miek.nl. 1800 IN TXT "Hello!" + 1800 RRSIG TXT 8 4 1800 ( + 20160426031301 20160327031301 12051 miek.nl. + O0uo1NsXTq2TTfgOmGbHQQEchrcpllaDAMMX + dTDizw3t+vZ5SR32qJ8W7y6VXLgUqJgcdRxS + Fou1pp+t5juRZSQ0LKgxMpZAgHorkzPvRf1b + E9eBKrDSuLGagsQRwHeldFGFgsXtCbf07vVH + zoKR8ynuG4/cAoY0JzMhCts+56U= ) + 14400 NSEC www.miek.nl. 
TXT RRSIG NSEC + 14400 RRSIG NSEC 8 4 14400 ( + 20160426031301 20160327031301 12051 miek.nl. + BW6qo7kYe3Z+Y0ebaVTWTy1c3bpdf8WUEoXq + WDQxLDEj2fFiuEBDaSN5lTWRg3wj8kZmr6Uk + LvX0P29lbATFarIgkyiAdbOEdaf88nMfqBW8 + z2T5xrPQcN0F13uehmv395yAJs4tebRxErMl + KdkVF0dskaDvw8Wo3YgjHUf6TXM= ) +www.miek.nl. 1800 IN CNAME a.miek.nl. + 1800 RRSIG CNAME 8 3 1800 ( + 20160426031301 20160327031301 12051 miek.nl. + MiQQh2lScoNiNVZmMJaypS+wDL2Lar4Zw1zF + Uo4tL16BfQOt7yl8gXdAH2JMFqoKAoIdM2K6 + XwFOwKTOGSW0oNCOcaE7ts+1Z1U0H3O2tHfq + FAzfg1s9pQ5zxk8J/bJgkVIkw2/cyB0y1/PK + EmIqvChBSb4NchTuMCSqo63LJM8= ) + 14400 NSEC miek.nl. CNAME RRSIG NSEC + 14400 RRSIG NSEC 8 3 14400 ( + 20160426031301 20160327031301 12051 miek.nl. + OPPZ8iaUPrVKEP4cqeCiiv1WLRAY30GRIhc/ + me0gBwFkbmTEnvB+rUp831OJZDZBNKv4QdZj + Uyc26wKUOQeUyMJqv4IRDgxH7nq9GB5JRjYZ + IVxtGD1aqWLXz+8aMaf9ARJjtYUd3K4lt8Wz + LbJSo5Wdq7GOWqhgkY5n3XD0/FA= )` diff --git a/plugin/file/dnssex_test.go b/plugin/file/dnssex_test.go new file mode 100644 index 0000000..d9a0a45 --- /dev/null +++ b/plugin/file/dnssex_test.go @@ -0,0 +1,145 @@ +package file + +const dbDnssexNLSigned = ` +; File written on Tue Mar 29 21:02:24 2016 +; dnssec_signzone version 9.10.3-P4-Ubuntu +dnssex.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. ( + 1459281744 ; serial + 14400 ; refresh (4 hours) + 3600 ; retry (1 hour) + 604800 ; expire (1 week) + 14400 ; minimum (4 hours) + ) + 1800 RRSIG SOA 8 2 1800 ( + 20160428190224 20160329190224 14460 dnssex.nl. + CA/Y3m9hCOiKC/8ieSOv8SeP964BUdG/8MC3 + WtKljUosK9Z9bBGrVizDjjqgq++lyH8BZJcT + aabAsERs4xj5PRtcxicwQXZACX5VYjXHQeZm + CyytFU5wq2gcXSmvUH86zZzftx3RGPvn1aOo + TlcvoC3iF8fYUCpROlUS0YR8Cdw= ) + 1800 NS omval.tednet.nl. + 1800 NS linode.atoom.net. + 1800 NS ns-ext.nlnetlabs.nl. + 1800 RRSIG NS 8 2 1800 ( + 20160428190224 20160329190224 14460 dnssex.nl. + dLIeEvP86jj5nd3orv9bH7hTvkblF4Na0sbl + k6fJA6ha+FPN1d6Pig3NNEEVQ/+wlOp/JTs2 + v07L7roEEUCbBprI8gMSld2gFDwNLW3DAB4M + WD/oayYdAnumekcLzhgvWixTABjWAGRTGQsP + sVDFXsGMf9TGGC9FEomgkCVeNC0= ) + 1800 A 139.162.196.78 + 1800 RRSIG A 8 2 1800 ( + 20160428190224 20160329190224 14460 dnssex.nl. + LKJKLzPiSEDWOLAag2YpfD5EJCuDcEAJu+FZ + Xy+4VyOv9YvRHCTL4vbrevOo5+XymY2RxU1q + j+6leR/Fe7nlreSj2wzAAk2bIYn4m6r7hqeO + aKZsUFfpX8cNcFtGEywfHndCPELbRxFeEziP + utqHFLPNMX5nYCpS28w4oJ5sAnM= ) + 1800 TXT "Doing It Safe Is Better" + 1800 RRSIG TXT 8 2 1800 ( + 20160428190224 20160329190224 14460 dnssex.nl. + f6S+DUfJK1UYdOb3AHgUXzFTTtu+yLp/Fv7S + Hv0CAGhXAVw+nBbK719igFvBtObS33WKwzxD + 1pQNMaJcS6zeevtD+4PKB1KDC4fyJffeEZT6 + E30jGR8Y29/xA+Fa4lqDNnj9zP3b8TiABCle + ascY5abkgWCALLocFAzFJQ/27YQ= ) + 1800 AAAA 2a01:7e00::f03c:91ff:fef1:6735 + 1800 RRSIG AAAA 8 2 1800 ( + 20160428190224 20160329190224 14460 dnssex.nl. + PWcPSawEUBAfCuv0liEOQ8RYe7tfNW4rubIJ + LE+dbrub1DUer3cWrDoCYFtOufvcbkYJQ2CQ + AGjJmAQ5J2aqYDOPMrKa615V0KT3ifbZJcGC + gkIic4U/EXjaQpRoLdDzR9MyVXOmbA6sKYzj + ju1cNkLqM8D7Uunjl4pIr6rdSFo= ) + 14400 NSEC *.dnssex.nl. A NS SOA TXT AAAA RRSIG NSEC DNSKEY + 14400 RRSIG NSEC 8 2 14400 ( + 20160428190224 20160329190224 14460 dnssex.nl. 
+ oIvM6JZIlNc1aNKGTxv58ApSnDr1nDPPgnD9 + 9oJZRIn7eb5WnpeDz2H3z5+x6Bhlp5hJJaUp + KJ3Ss6Jg/IDnrmIvKmgq6L6gHj1Y1IiHmmU8 + VeZTRzdTsDx/27OsN23roIvsytjveNSEMfIm + iLZ23x5kg1kBdJ9p3xjYHm5lR+8= ) + 1800 DNSKEY 256 3 8 ( + AwEAAazSO6uvLPEVknDA8yxjFe8nnAMU7txp + wb19k55hQ81WV3G4bpBM1NdN6sbYHrkXaTNx + 2bQWAkvX6pz0XFx3z/MPhW+vkakIWFYpyQ7R + AT5LIJfToVfiCDiyhhF0zVobKBInO9eoGjd9 + BAW3TUt+LmNAO/Ak5D5BX7R3CuA7v9k7 + ) ; ZSK; alg = RSASHA256; key id = 14460 + 1800 DNSKEY 257 3 8 ( + AwEAAbyeaV9zg0IqdtgYoqK5jJ239anzwG2i + gvH1DxSazLyaoNvEkCIvPgMLW/JWfy7Z1mQp + SMy9DtzL5pzRyQgw7kIeXLbi6jufUFd9pxN+ + xnzKLf9mY5AcnGToTrbSL+jnMT67wG+c34+Q + PeVfucHNUePBxsbz2+4xbXiViSQyCQGv + ) ; KSK; alg = RSASHA256; key id = 18772 + 1800 RRSIG DNSKEY 8 2 1800 ( + 20160428190224 20160329190224 14460 dnssex.nl. + cFSFtJE+DBGNxb52AweFaVHBe5Ue5MDpqNdC + TIneUnEhP2m+vK4zJ/TraK0WdQFpsX63pod8 + PZ9y03vHUfewivyonCCBD3DcNdoU9subhN22 + tez9Ct8Z5/9E4RAz7orXal4M1VUEhRcXSEH8 + SJW20mfVsqJAiKqqNeGB/pAj23I= ) + 1800 RRSIG DNSKEY 8 2 1800 ( + 20160428190224 20160329190224 18772 dnssex.nl. + oiiwo/7NYacePqohEp50261elhm6Dieh4j2S + VZGAHU5gqLIQeW9CxKJKtSCkBVgUo4cvO4Rn + 2tzArAuclDvBrMXRIoct8u7f96moeFE+x5FI + DYqICiV6k449ljj9o4t/5G7q2CRsEfxZKpTI + A/L0+uDk0RwVVzL45+TnilcsmZs= ) +*.dnssex.nl. 1800 IN TXT "Doing It Safe Is Better" + 1800 RRSIG TXT 8 2 1800 ( + 20160428190224 20160329190224 14460 dnssex.nl. + FUZSTyvZfeuuOpCmNzVKOfITRHJ6/ygjmnnb + XGBxVUyQjoLuYXwD5XqZWGw4iKH6QeSDfGCx + 4MPqA4qQmW7Wwth7mat9yMfA4+p2sO84bysl + 7/BG9+W2G+q1uQiM9bX9V42P2X/XuW5Y/t9Y + 8u1sljQ7D8WwS6naH/vbaJxnDBw= ) + 14400 NSEC a.dnssex.nl. TXT RRSIG NSEC + 14400 RRSIG NSEC 8 2 14400 ( + 20160428190224 20160329190224 14460 dnssex.nl. + os6INm6q2eXknD5z8TpfbK00uxVbQefMvHcR + /RNX/kh0xXvzAaaDOV+Ge/Ko+2dXnKP+J1LY + G9ffXNpdbaQy5ygzH5F041GJst4566GdG/jt + 7Z7vLHYxEBTpZfxo+PLsXQXH3VTemZyuWyDf + qJzafXJVH1F0nDrcXmMlR6jlBHA= ) +www.dnssex.nl. 1800 IN CNAME a.dnssex.nl. + 1800 RRSIG CNAME 8 3 1800 ( + 20160428190224 20160329190224 14460 dnssex.nl. + Omv42q/uVvdNsWQoSrQ6m6w6U7r7Abga7uF4 + 25b3gZlse0C+WyMyGFMGUbapQm7azvBpreeo + uKJHjzd+ufoG+Oul6vU9vyoj+ejgHzGLGbJQ + HftfP+UqP5SWvAaipP/LULTWKPuiBcLDLiBI + PGTfsq0DB6R+qCDTV0fNnkgxEBQ= ) + 14400 NSEC dnssex.nl. CNAME RRSIG NSEC + 14400 RRSIG NSEC 8 3 14400 ( + 20160428190224 20160329190224 14460 dnssex.nl. + TBN3ddfZW+kC84/g3QlNNJMeLZoyCalPQylt + KXXLPGuxfGpl3RYRY8KaHbP+5a8MnHjqjuMB + Lofb7yKMFxpSzMh8E36vnOqry1mvkSakNj9y + 9jM8PwDjcpYUwn/ql76MsmNgEV5CLeQ7lyH4 + AOrL79yOSQVI3JHJIjKSiz88iSw= ) +a.dnssex.nl. 1800 IN A 139.162.196.78 + 1800 RRSIG A 8 3 1800 ( + 20160428190224 20160329190224 14460 dnssex.nl. + OXHpFj9nSpKi5yA/ULH7MOpGAWfyJ2yC/2xa + Pw0fqSY4QvcRt+V3adcFA4H9+P1b32GpxEjB + lXmCJID+H4lYkhUR4r4IOZBVtKG2SJEBZXip + pH00UkOIBiXxbGzfX8VL04v2G/YxUgLW57kA + aknaeTOkJsO20Y+8wmR9EtzaRFI= ) + 1800 AAAA 2a01:7e00::f03c:91ff:fef1:6735 + 1800 RRSIG AAAA 8 3 1800 ( + 20160428190224 20160329190224 14460 dnssex.nl. + jrepc/VnRzJypnrG0WDEqaAr3HMjWrPxJNX0 + 86gbFjZG07QxBmrA1rj0jM9YEWTjjyWb2tT7 + lQhzKDYX/0XdOVUeeOM4FoSks80V+pWR8fvj + AZ5HmX69g36tLosMDKNR4lXcrpv89QovG4Hr + /r58fxEKEFJqrLDjMo6aOrg+uKA= ) + 14400 NSEC www.dnssex.nl. A AAAA RRSIG NSEC + 14400 RRSIG NSEC 8 3 14400 ( + 20160428190224 20160329190224 14460 dnssex.nl. 
+ S+UM62wXRNNFN3QDWK5YFWUbHBXC4aqaqinZ + A2ZDeC+IQgyw7vazPz7cLI5T0YXXks0HTMlr + soEjKnnRZsqSO9EuUavPNE1hh11Jjm0fB+5+ + +Uro0EmA5Dhgc0Z2VpbXVQEhNDf/pI1gem15 + RffN2tBYNykZn4Has2ySgRaaRYQ= )` diff --git a/plugin/file/ds_test.go b/plugin/file/ds_test.go new file mode 100644 index 0000000..74f7bbd --- /dev/null +++ b/plugin/file/ds_test.go @@ -0,0 +1,77 @@ +package file + +import ( + "context" + "strings" + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +var dsTestCases = []test.Case{ + { + Qname: "a.delegated.miek.nl.", Qtype: dns.TypeDS, + Ns: []dns.RR{ + test.NS("delegated.miek.nl. 1800 IN NS a.delegated.miek.nl."), + test.NS("delegated.miek.nl. 1800 IN NS ns-ext.nlnetlabs.nl."), + }, + Extra: []dns.RR{ + test.A("a.delegated.miek.nl. 1800 IN A 139.162.196.78"), + test.AAAA("a.delegated.miek.nl. 1800 IN AAAA 2a01:7e00::f03c:91ff:fef1:6735"), + }, + }, + { + Qname: "_udp.delegated.miek.nl.", Qtype: dns.TypeDS, + Ns: []dns.RR{ + test.NS("delegated.miek.nl. 1800 IN NS a.delegated.miek.nl."), + test.NS("delegated.miek.nl. 1800 IN NS ns-ext.nlnetlabs.nl."), + }, + Extra: []dns.RR{ + test.A("a.delegated.miek.nl. 1800 IN A 139.162.196.78"), + test.AAAA("a.delegated.miek.nl. 1800 IN AAAA 2a01:7e00::f03c:91ff:fef1:6735"), + }, + }, + { + // This works *here* because we skip the server routing for DS in core/dnsserver/server.go + Qname: "_udp.miek.nl.", Qtype: dns.TypeDS, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 1282630057 14400 3600 604800 14400"), + }, + }, + { + Qname: "miek.nl.", Qtype: dns.TypeDS, + Ns: []dns.RR{ + test.SOA("miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 1282630057 14400 3600 604800 14400"), + }, + }, +} + +func TestLookupDS(t *testing.T) { + zone, err := Parse(strings.NewReader(dbMiekNLDelegation), testzone, "stdin", 0) + if err != nil { + t.Fatalf("Expected no error when reading zone, got %q", err) + } + + fm := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{testzone: zone}, Names: []string{testzone}}} + ctx := context.TODO() + + for _, tc := range dsTestCases { + m := tc.Msg() + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err := fm.ServeDNS(ctx, rec, m) + if err != nil { + t.Errorf("Expected no error, got %v", err) + return + } + + resp := rec.Msg + if err := test.SortAndCheck(resp, tc); err != nil { + t.Error(err) + } + } +} diff --git a/plugin/file/ent_test.go b/plugin/file/ent_test.go new file mode 100644 index 0000000..73f5085 --- /dev/null +++ b/plugin/file/ent_test.go @@ -0,0 +1,159 @@ +package file + +import ( + "context" + "strings" + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +var entTestCases = []test.Case{ + { + Qname: "b.c.miek.nl.", Qtype: dns.TypeA, + Ns: []dns.RR{ + test.SOA("miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 1282630057 14400 3600 604800 14400"), + }, + }, + { + Qname: "b.c.miek.nl.", Qtype: dns.TypeA, Do: true, + Ns: []dns.RR{ + test.NSEC("a.miek.nl. 14400 IN NSEC a.b.c.miek.nl. A RRSIG NSEC"), + test.RRSIG("a.miek.nl. 14400 IN RRSIG NSEC 8 3 14400 20160502144311 20160402144311 12051 miek.nl. d5XZEy6SUpq98ZKUlzqhAfkLI9pQPc="), + test.RRSIG("miek.nl. 1800 IN RRSIG SOA 8 2 1800 20160502144311 20160402144311 12051 miek.nl. KegoBxA3Tbrhlc4cEdkRiteIkOfsq"), + test.SOA("miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 
1282630057 14400 3600 604800 14400"), + }, + }, +} + +func TestLookupEnt(t *testing.T) { + zone, err := Parse(strings.NewReader(dbMiekENTNL), testzone, "stdin", 0) + if err != nil { + t.Fatalf("Expect no error when reading zone, got %q", err) + } + + fm := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{testzone: zone}, Names: []string{testzone}}} + ctx := context.TODO() + + for _, tc := range entTestCases { + m := tc.Msg() + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err := fm.ServeDNS(ctx, rec, m) + if err != nil { + t.Errorf("Expected no error, got %v", err) + return + } + + resp := rec.Msg + if err := test.SortAndCheck(resp, tc); err != nil { + t.Error(err) + } + } +} + +const dbMiekENTNL = `; File written on Sat Apr 2 16:43:11 2016 +; dnssec_signzone version 9.10.3-P4-Ubuntu +miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. ( + 1282630057 ; serial + 14400 ; refresh (4 hours) + 3600 ; retry (1 hour) + 604800 ; expire (1 week) + 14400 ; minimum (4 hours) + ) + 1800 RRSIG SOA 8 2 1800 ( + 20160502144311 20160402144311 12051 miek.nl. + KegoBxA3Tbrhlc4cEdkRiteIkOfsqD4oCLLM + ISJ5bChWy00LGHUlAnHVu5Ti96hUjVNmGSxa + xtGSuAAMFCr52W8pAB8LBIlu9B6QZUPHMccr + SuzxAX3ioawk2uTjm+k8AGPT4RoQdXemGLAp + zJTASolTVmeMTh5J0sZTZJrtvZ0= ) + 1800 NS linode.atoom.net. + 1800 RRSIG NS 8 2 1800 ( + 20160502144311 20160402144311 12051 miek.nl. + m0cOHL6Rre/0jZPXe+0IUjs/8AFASRCvDbSx + ZQsRDSlZgS6RoMP3OC77cnrKDVlfZ2Vhq3Ce + nYPoGe0/atB92XXsilmstx4HTSU64gsV9iLN + Xkzk36617t7zGOl/qumqfaUXeA9tihItzEim + 6SGnufVZI4o8xeyaVCNDDuN0bvY= ) + 14400 NSEC a.miek.nl. NS SOA RRSIG NSEC DNSKEY + 14400 RRSIG NSEC 8 2 14400 ( + 20160502144311 20160402144311 12051 miek.nl. + BCWVgwxWrs4tBjS9QXKkftCUbiLi40NyH1yA + nbFy1wCKQ2jDH00810+ia4b66QrjlAKgxE9z + 9U7MKSMV86sNkyAtlCi+2OnjtWF6sxPdJO7k + CHeg46XBjrQuiJRY8CneQX56+IEPdufLeqPR + l+ocBQ2UkGhXmQdWp3CFDn2/eqU= ) + 1800 DNSKEY 256 3 8 ( + AwEAAcNEU67LJI5GEgF9QLNqLO1SMq1EdoQ6 + E9f85ha0k0ewQGCblyW2836GiVsm6k8Kr5EC + IoMJ6fZWf3CQSQ9ycWfTyOHfmI3eQ/1Covhb + 2y4bAmL/07PhrL7ozWBW3wBfM335Ft9xjtXH + Py7ztCbV9qZ4TVDTW/Iyg0PiwgoXVesz + ) ; ZSK; alg = RSASHA256; key id = 12051 + 1800 DNSKEY 257 3 8 ( + AwEAAcWdjBl4W4wh/hPxMDcBytmNCvEngIgB + 9Ut3C2+QI0oVz78/WK9KPoQF7B74JQ/mjO4f + vIncBmPp6mFNxs9/WQX0IXf7oKviEVOXLjct + R4D1KQLX0wprvtUIsQFIGdXaO6suTT5eDbSd + 6tTwu5xIkGkDmQhhH8OQydoEuCwV245ZwF/8 + AIsqBYDNQtQ6zhd6jDC+uZJXg/9LuPOxFHbi + MTjp6j3CCW0kHbfM/YHZErWWtjPj3U3Z7knQ + SIm5PO5FRKBEYDdr5UxWJ/1/20SrzI3iztvP + wHDsA2rdHm/4YRzq7CvG4N0t9ac/T0a0Sxba + /BUX2UVPWaIVBdTRBtgHi0s= + ) ; KSK; alg = RSASHA256; key id = 33694 + 1800 RRSIG DNSKEY 8 2 1800 ( + 20160502144311 20160402144311 12051 miek.nl. + YNpi1jRDQKpnsQEjIjxqy+kJGaYnV16e8Iug + 40c82y4pee7kIojFUllSKP44qiJpCArxF557 + tfjfwBd6c4hkqCScGPZXJ06LMyG4u//rhVMh + 4hyKcxzQFKxmrFlj3oQGksCI8lxGX6RxiZuR + qv2ol2lUWrqetpAL+Zzwt71884E= ) + 1800 RRSIG DNSKEY 8 2 1800 ( + 20160502144311 20160402144311 33694 miek.nl. + jKpLDEeyadgM0wDgzEk6sBBdWr2/aCrkAOU/ + w6dYIafN98f21oIYQfscV1gc7CTsA0vwzzUu + x0QgwxoNLMvSxxjOiW/2MzF8eozczImeCWbl + ad/pVCYH6Jn5UBrZ5RCWMVcs2RP5KDXWeXKs + jEN/0EmQg5qNd4zqtlPIQinA9I1HquJAnS56 + pFvYyGIbZmGEbhR18sXVBeTWYr+zOMHn2quX + 0kkrx2udz+sPg7i4yRsLdhw138gPRy1qvbaC + 8ELs1xo1mC9pTlDOhz24Q3iXpVAU1lXLYOh9 + nUP1/4UvZEYXHBUQk/XPRciojniWjAF825x3 + QoSivMHblBwRdAKJSg== ) +a.miek.nl. 1800 IN A 127.0.0.1 + 1800 RRSIG A 8 3 1800 ( + 20160502144311 20160402144311 12051 miek.nl. 
+ lUOYdSxScjyYz+Ebc+nb6iTNgCohqj7K+Dat + 97KE7haV2nP3LxdYuDCJYZpeyhsXDLHd4bFI + bInYPwJiC6DUCxPCuCWy0KYlZOWW8KCLX3Ia + BOPQbvIwLsJhnX+/tyMD9mXortoqATO79/6p + nNxvFeM8pFDwaih17fXMuFR/BsI= ) + 14400 NSEC a.b.c.miek.nl. A RRSIG NSEC + 14400 RRSIG NSEC 8 3 14400 ( + 20160502144311 20160402144311 12051 miek.nl. + d5XZEy6SUp+TPRJQED+0R65zf2Yeo/1dlEA2 + jYYvkXGSHXke4sg9nH8U3nr1rLcuqA1DsQgH + uMIjdENvXuZ+WCSwvIbhC+JEI6AyQ6Gfaf/D + I3mfu60C730IRByTrKM5C2rt11lwRQlbdaUY + h23/nn/q98ZKUlzqhAfkLI9pQPc= ) +a.b.c.miek.nl. 1800 IN A 127.0.0.1 + 1800 RRSIG A 8 5 1800 ( + 20160502144311 20160402144311 12051 miek.nl. + FwgU5+fFD4hEebco3gvKQt3PXfY+dcOJr8dl + Ky4WLsONIdhP+4e9oprPisSLxImErY21BcrW + xzu1IZrYDsS8XBVV44lBx5WXEKvAOrUcut/S + OWhFZW7ncdIQCp32ZBIatiLRJEqXUjx+guHs + noFLiHix35wJWsRKwjGLIhH1fbs= ) + 14400 NSEC miek.nl. A RRSIG NSEC + 14400 RRSIG NSEC 8 5 14400 ( + 20160502144311 20160402144311 12051 miek.nl. + lXgOqm9/jRRYvaG5jC1CDvTtGYxMroTzf4t4 + jeYGb60+qI0q9sHQKfAJvoQ5o8o1qfR7OuiF + f544ipYT9eTcJRyGAOoJ37yMie7ZIoVJ91tB + r8YdzZ9Q6x3v1cbwTaQiacwhPZhGYOw63qIs + q5IQErIPos2sNk+y9D8BEce2DO4= )` diff --git a/plugin/file/example_org.go b/plugin/file/example_org.go new file mode 100644 index 0000000..eba18e0 --- /dev/null +++ b/plugin/file/example_org.go @@ -0,0 +1,113 @@ +package file + +// exampleOrgSigned is a fake signed example.org zone with two delegations, +// one signed (with DSs) and one "normal". +const exampleOrgSigned = ` +example.org. 1800 IN SOA a.iana-servers.net. devnull.example.org. ( + 1282630057 ; serial + 14400 ; refresh (4 hours) + 3600 ; retry (1 hour) + 604800 ; expire (1 week) + 14400 ; minimum (4 hours) + ) + 1800 RRSIG SOA 13 2 1800 ( + 20161129153240 20161030153240 49035 example.org. + GVnMpFmN+6PDdgCtlYDEYBsnBNDgYmEJNvos + Bk9+PNTPNWNst+BXCpDadTeqRwrr1RHEAQ7j + YWzNwqn81pN+IA== ) + 1800 NS a.iana-servers.net. + 1800 NS b.iana-servers.net. + 1800 RRSIG NS 13 2 1800 ( + 20161129153240 20161030153240 49035 example.org. + llrHoIuwjnbo28LOt4p5zWAs98XGqrXicKVI + Qxyaf/ORM8boJvW2XrKr3nj6Y8FKMhzd287D + 5PBzVCL6MZyjQg== ) + 14400 NSEC a.example.org. NS SOA RRSIG NSEC DNSKEY + 14400 RRSIG NSEC 13 2 14400 ( + 20161129153240 20161030153240 49035 example.org. + BQROf1swrmYi3GqpP5M/h5vTB8jmJ/RFnlaX + 7fjxvV7aMvXCsr3ekWeB2S7L6wWFihDYcKJg + 9BxVPqxzBKeaqg== ) + 1800 DNSKEY 256 3 13 ( + UNTqlHbC51EbXuY0rshW19Iz8SkCuGVS+L0e + bQj53dvtNlaKfWmtTauC797FoyVLbQwoMy/P + G68SXgLCx8g+9g== + ) ; ZSK; alg = ECDSAP256SHA256; key id = 49035 + 1800 RRSIG DNSKEY 13 2 1800 ( + 20161129153240 20161030153240 49035 example.org. + LnLHyqYJaCMOt7EHB4GZxzAzWLwEGCTFiEhC + jj1X1VuQSjJcN42Zd3yF+jihSW6huknrig0Z + Mqv0FM6mJ/qPKg== ) +a.delegated.example.org. 1800 IN A 139.162.196.78 + 1800 TXT "obscured" + 1800 AAAA 2a01:7e00::f03c:91ff:fef1:6735 +archive.example.org. 1800 IN CNAME a.example.org. + 1800 RRSIG CNAME 13 3 1800 ( + 20161129153240 20161030153240 49035 example.org. + SDFW1z/PN9knzH8BwBvmWK0qdIwMVtGrMgRw + 7lgy4utRrdrRdCSLZy3xpkmkh1wehuGc4R0S + 05Z3DPhB0Fg5BA== ) + 14400 NSEC delegated.example.org. CNAME RRSIG NSEC + 14400 RRSIG NSEC 13 3 14400 ( + 20161129153240 20161030153240 49035 example.org. + DQqLSVNl8F6v1K09wRU6/M6hbHy2VUddnOwn + JusJjMlrAOmoOctCZ/N/BwqCXXBA+d9yFGdH + knYumXp+BVPBAQ== ) +www.example.org. 1800 IN CNAME a.example.org. + 1800 RRSIG CNAME 13 3 1800 ( + 20161129153240 20161030153240 49035 example.org. + adzujOxCV0uBV4OayPGfR11iWBLiiSAnZB1R + slmhBFaDKOKSNYijGtiVPeaF+EuZs63pzd4y + 6Nm2Iq9cQhAwAA== ) + 14400 NSEC example.org. 
CNAME RRSIG NSEC + 14400 RRSIG NSEC 13 3 14400 ( + 20161129153240 20161030153240 49035 example.org. + jy3f96GZGBaRuQQjuqsoP1YN8ObZF37o+WkV + PL7TruzI7iNl0AjrUDy9FplP8Mqk/HWyvlPe + N3cU+W8NYlfDDQ== ) +a.example.org. 1800 IN A 139.162.196.78 + 1800 RRSIG A 13 3 1800 ( + 20161129153240 20161030153240 49035 example.org. + 41jFz0Dr8tZBN4Kv25S5dD4vTmviFiLx7xSA + qMIuLFm0qibKL07perKpxqgLqM0H1wreT4xz + I9Y4Dgp1nsOuMA== ) + 1800 AAAA 2a01:7e00::f03c:91ff:fef1:6735 + 1800 RRSIG AAAA 13 3 1800 ( + 20161129153240 20161030153240 49035 example.org. + brHizDxYCxCHrSKIu+J+XQbodRcb7KNRdN4q + VOWw8wHqeBsFNRzvFF6jwPQYphGP7kZh1KAb + VuY5ZVVhM2kHjw== ) + 14400 NSEC archive.example.org. A AAAA RRSIG NSEC + 14400 RRSIG NSEC 13 3 14400 ( + 20161129153240 20161030153240 49035 example.org. + zIenVlg5ScLr157EWigrTGUgrv7W/1s49Fic + i2k+OVjZfT50zw+q5X6DPKkzfAiUhIuqs53r + hZUzZwV/1Wew9Q== ) +delegated.example.org. 1800 IN NS a.delegated.example.org. + 1800 IN NS ns-ext.nlnetlabs.nl. + 1800 DS 10056 5 1 ( + EE72CABD1927759CDDA92A10DBF431504B9E + 1F13 ) + 1800 DS 10056 5 2 ( + E4B05F87725FA86D9A64F1E53C3D0E625094 + 6599DFE639C45955B0ED416CDDFA ) + 1800 RRSIG DS 13 3 1800 ( + 20161129153240 20161030153240 49035 example.org. + rlNNzcUmtbjLSl02ZzQGUbWX75yCUx0Mug1j + HtKVqRq1hpPE2S3863tIWSlz+W9wz4o19OI4 + jbznKKqk+DGKog== ) + 14400 NSEC sub.example.org. NS DS RRSIG NSEC + 14400 RRSIG NSEC 13 3 14400 ( + 20161129153240 20161030153240 49035 example.org. + lNQ5kRTB26yvZU5bFn84LYFCjwWTmBcRCDbD + cqWZvCSw4LFOcqbz1/wJKIRjIXIqnWIrfIHe + fZ9QD5xZsrPgUQ== ) +sub.example.org. 1800 IN NS sub1.example.net. + 1800 IN NS sub2.example.net. + 14400 NSEC www.example.org. NS RRSIG NSEC + 14400 RRSIG NSEC 13 3 14400 ( + 20161129153240 20161030153240 49035 example.org. + VYjahdV+TTkA3RBdnUI0hwXDm6U5k/weeZZr + ix1znORpOELbeLBMJW56cnaG+LGwOQfw9qqj + bOuULDst84s4+g== ) +` diff --git a/plugin/file/file.go b/plugin/file/file.go new file mode 100644 index 0000000..f50c3d0 --- /dev/null +++ b/plugin/file/file.go @@ -0,0 +1,164 @@ +// Package file implements a file backend. +package file + +import ( + "context" + "fmt" + "io" + + "github.com/coredns/coredns/plugin" + clog "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/plugin/transfer" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +var log = clog.NewWithPlugin("file") + +type ( + // File is the plugin that reads zone data from disk. + File struct { + Next plugin.Handler + Zones + transfer *transfer.Transfer + } + + // Zones maps zone names to a *Zone. + Zones struct { + Z map[string]*Zone // A map mapping zone (origin) to the Zone's data + Names []string // All the keys from the map Z as a string slice. + } +) + +// ServeDNS implements the plugin.Handle interface. +func (f File) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + state := request.Request{W: w, Req: r} + + qname := state.Name() + // TODO(miek): match the qname better in the map + zone := plugin.Zones(f.Zones.Names).Matches(qname) + if zone == "" { + return plugin.NextOrFailure(f.Name(), f.Next, ctx, w, r) + } + + z, ok := f.Zones.Z[zone] + if !ok || z == nil { + return dns.RcodeServerFailure, nil + } + + // If transfer is not loaded, we'll see these, answer with refused (no transfer allowed). + if state.QType() == dns.TypeAXFR || state.QType() == dns.TypeIXFR { + return dns.RcodeRefused, nil + } + + // This is only for when we are a secondary zones. 
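+	// The code below acknowledges a NOTIFY from a configured primary right away,
+	// then checks the SOA serial (shouldTransfer) and, if the primary is ahead,
+	// pulls the zone in-line via TransferIn. Roughly:
+	//
+	//	NOTIFY -> isNotify? -> reply NOERROR -> shouldTransfer? -> TransferIn
+	//
+	// NOTIFYs that do not match a configured transfer source are logged and ignored.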
+ if r.Opcode == dns.OpcodeNotify { + if z.isNotify(state) { + m := new(dns.Msg) + m.SetReply(r) + m.Authoritative = true + w.WriteMsg(m) + + log.Infof("Notify from %s for %s: checking transfer", state.IP(), zone) + ok, err := z.shouldTransfer() + if ok { + z.TransferIn() + } else { + log.Infof("Notify from %s for %s: no SOA serial increase seen", state.IP(), zone) + } + if err != nil { + log.Warningf("Notify from %s for %s: failed primary check: %s", state.IP(), zone, err) + } + return dns.RcodeSuccess, nil + } + log.Infof("Dropping notify from %s for %s", state.IP(), zone) + return dns.RcodeSuccess, nil + } + + z.RLock() + exp := z.Expired + z.RUnlock() + if exp { + log.Errorf("Zone %s is expired", zone) + return dns.RcodeServerFailure, nil + } + + answer, ns, extra, result := z.Lookup(ctx, state, qname) + + m := new(dns.Msg) + m.SetReply(r) + m.Authoritative = true + m.Answer, m.Ns, m.Extra = answer, ns, extra + + switch result { + case Success: + case NoData: + case NameError: + m.Rcode = dns.RcodeNameError + case Delegation: + m.Authoritative = false + case ServerFailure: + // If the result is SERVFAIL and the answer is non-empty, then the SERVFAIL came from an + // external CNAME lookup and the answer contains the CNAME with no target record. We should + // write the CNAME record to the client instead of sending an empty SERVFAIL response. + if len(m.Answer) == 0 { + return dns.RcodeServerFailure, nil + } + // The rcode in the response should be the rcode received from the target lookup. RFC 6604 section 3 + m.Rcode = dns.RcodeServerFailure + } + + w.WriteMsg(m) + return dns.RcodeSuccess, nil +} + +// Name implements the Handler interface. +func (f File) Name() string { return "file" } + +type serialErr struct { + err string + zone string + origin string + serial int64 +} + +func (s *serialErr) Error() string { + return fmt.Sprintf("%s for origin %s in file %s, with %d SOA serial", s.err, s.origin, s.zone, s.serial) +} + +// Parse parses the zone in filename and returns a new Zone or an error. +// If serial >= 0 it will reload the zone, if the SOA hasn't changed +// it returns an error indicating nothing was read. +func Parse(f io.Reader, origin, fileName string, serial int64) (*Zone, error) { + zp := dns.NewZoneParser(f, dns.Fqdn(origin), fileName) + zp.SetIncludeAllowed(true) + z := NewZone(origin, fileName) + seenSOA := false + for rr, ok := zp.Next(); ok; rr, ok = zp.Next() { + if err := zp.Err(); err != nil { + return nil, err + } + + if !seenSOA { + if s, ok := rr.(*dns.SOA); ok { + seenSOA = true + + // -1 is valid serial is we failed to load the file on startup. 
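+			// In other words: a negative serial (the -1 sentinel returned by
+			// SOASerialIfDefined in reload.go) always forces a full parse, while a
+			// serial equal to the one on disk short-circuits with a *serialErr so a
+			// periodic reload can skip an unchanged zone.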
+ + if serial >= 0 && s.Serial == uint32(serial) { // same serial + return nil, &serialErr{err: "no change in SOA serial", origin: origin, zone: fileName, serial: serial} + } + } + } + + if err := z.Insert(rr); err != nil { + return nil, err + } + } + if !seenSOA { + return nil, fmt.Errorf("file %q has no SOA record for origin %s", fileName, origin) + } + + return z, nil +} diff --git a/plugin/file/file_test.go b/plugin/file/file_test.go new file mode 100644 index 0000000..0e4050e --- /dev/null +++ b/plugin/file/file_test.go @@ -0,0 +1,31 @@ +package file + +import ( + "strings" + "testing" +) + +func BenchmarkFileParseInsert(b *testing.B) { + for i := 0; i < b.N; i++ { + Parse(strings.NewReader(dbMiekENTNL), testzone, "stdin", 0) + } +} + +func TestParseNoSOA(t *testing.T) { + _, err := Parse(strings.NewReader(dbNoSOA), "example.org.", "stdin", 0) + if err == nil { + t.Fatalf("Zone %q should have failed to load", "example.org.") + } + if !strings.Contains(err.Error(), "no SOA record") { + t.Fatalf("Zone %q should have failed to load with no soa error: %s", "example.org.", err) + } +} + +const dbNoSOA = ` +$TTL 1M +$ORIGIN example.org. + +www IN A 192.168.0.14 +mail IN A 192.168.0.15 +imap IN CNAME mail +` diff --git a/plugin/file/fuzz.go b/plugin/file/fuzz.go new file mode 100644 index 0000000..9c59ab8 --- /dev/null +++ b/plugin/file/fuzz.go @@ -0,0 +1,50 @@ +//go:build gofuzz + +package file + +import ( + "strings" + + "github.com/coredns/coredns/plugin/pkg/fuzz" + "github.com/coredns/coredns/plugin/test" +) + +// Fuzz fuzzes file. +func Fuzz(data []byte) int { + name := "miek.nl." + zone, _ := Parse(strings.NewReader(fuzzMiekNL), name, "stdin", 0) + f := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{name: zone}, Names: []string{name}}} + + return fuzz.Do(f, data) +} + +const fuzzMiekNL = ` +$TTL 30M +$ORIGIN miek.nl. +@ IN SOA linode.atoom.net. miek.miek.nl. ( + 1282630057 ; Serial + 4H ; Refresh + 1H ; Retry + 7D ; Expire + 4H ) ; Negative Cache TTL + IN NS linode.atoom.net. + IN NS ns-ext.nlnetlabs.nl. + IN NS omval.tednet.nl. + IN NS ext.ns.whyscream.net. + + IN MX 1 aspmx.l.google.com. + IN MX 5 alt1.aspmx.l.google.com. + IN MX 5 alt2.aspmx.l.google.com. + IN MX 10 aspmx2.googlemail.com. + IN MX 10 aspmx3.googlemail.com. + + IN A 139.162.196.78 + IN AAAA 2a01:7e00::f03c:91ff:fef1:6735 + +a IN A 139.162.196.78 + IN AAAA 2a01:7e00::f03c:91ff:fef1:6735 +www IN CNAME a +archive IN CNAME a + +srv IN SRV 10 10 8080 a.miek.nl. +mx IN MX 10 a.miek.nl.` diff --git a/plugin/file/glue_test.go b/plugin/file/glue_test.go new file mode 100644 index 0000000..eeddc4e --- /dev/null +++ b/plugin/file/glue_test.go @@ -0,0 +1,254 @@ +package file + +import ( + "context" + "strings" + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +// another personal zone (helps in testing as my secondary is NSD, atoom = atom in English. +var atoomTestCases = []test.Case{ + { + Qname: atoom, Qtype: dns.TypeNS, Do: true, + Answer: []dns.RR{ + test.NS("atoom.net. 1800 IN NS linode.atoom.net."), + test.NS("atoom.net. 1800 IN NS ns-ext.nlnetlabs.nl."), + test.NS("atoom.net. 1800 IN NS omval.tednet.nl."), + test.RRSIG("atoom.net. 1800 IN RRSIG NS 8 2 1800 20170112031301 20161213031301 53289 atoom.net. DLe+G1 jlw="), + }, + Extra: []dns.RR{ + // test.OPT(4096, true), // added by server, not test in this unit test. + test.A("linode.atoom.net. 1800 IN A 176.58.119.54"), + test.AAAA("linode.atoom.net. 
1800 IN AAAA 2a01:7e00::f03c:91ff:fe79:234c"), + test.RRSIG("linode.atoom.net. 1800 IN RRSIG A 8 3 1800 20170112031301 20161213031301 53289 atoom.net. Z4Ka4OLDoyxj72CL vkI="), + test.RRSIG("linode.atoom.net. 1800 IN RRSIG AAAA 8 3 1800 20170112031301 20161213031301 53289 atoom.net. l+9Qc914zFH/okG2fzJ1q olQ="), + }, + }, +} + +func TestLookupGlue(t *testing.T) { + zone, err := Parse(strings.NewReader(dbAtoomNetSigned), atoom, "stdin", 0) + if err != nil { + t.Fatalf("Expected no error when reading zone, got %q", err) + } + + fm := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{atoom: zone}, Names: []string{atoom}}} + ctx := context.TODO() + + for _, tc := range atoomTestCases { + m := tc.Msg() + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err := fm.ServeDNS(ctx, rec, m) + if err != nil { + t.Errorf("Expected no error, got %v", err) + return + } + + resp := rec.Msg + if err := test.SortAndCheck(resp, tc); err != nil { + t.Error(err) + } + } +} + +const dbAtoomNetSigned = ` +; File written on Tue Dec 13 04:13:01 2016 +; dnssec_signzone version 9.10.3-P4-Debian +atoom.net. 1800 IN SOA linode.atoom.net. miek.miek.nl. ( + 1481602381 ; serial + 14400 ; refresh (4 hours) + 3600 ; retry (1 hour) + 604800 ; expire (1 week) + 14400 ; minimum (4 hours) + ) + 1800 RRSIG SOA 8 2 1800 ( + 20170112031301 20161213031301 53289 atoom.net. + GZ30uFuGATKzwHXgpEwK70qjdXSAqmbB5d4z + e7WTibvJDPLa1ptZBI7Zuod2KMOkT1ocSvhL + U7makhdv0BQx+5RSaP25mAmPIzfU7/T7R+DJ + 5q1GLlDSvOprfyMUlwOgZKZinesSdUa9gRmu + 8E+XnPNJ/jcTrGzzaDjn1/irrM0= ) + 1800 NS omval.tednet.nl. + 1800 NS linode.atoom.net. + 1800 NS ns-ext.nlnetlabs.nl. + 1800 RRSIG NS 8 2 1800 ( + 20170112031301 20161213031301 53289 atoom.net. + D8Sd9JpXIOxOrUF5Hi1ASutyQwP7JNu8XZxA + rse86A6L01O8H8sCNib2VEoJjHuZ/dDEogng + OgmfqeFy04cpSX19GAk3bkx8Lr6aEat3nqIC + XA/xsCCfXy0NKZpI05zntHPbbP5tF/NvpE7n + 0+oLtlHSPEg1ZnEgwNoLe+G1jlw= ) + 1800 A 176.58.119.54 + 1800 RRSIG A 8 2 1800 ( + 20170112031301 20161213031301 53289 atoom.net. + mrjiUFNCqDgCW8TuhjzcMh0V841uC224QvwH + 0+OvYhcve9twbX3Y12PSFmz77Xz3Jg9WAj4I + qhh3iHUac4dzUXyC702DT62yMF/9CMUO0+Ee + b6wRtvPHr2Tt0i/xV/BTbArInIvurXJrvKvo + LsZHOfsg7dZs6Mvdpe/CgwRExpk= ) + 1800 AAAA 2a01:7e00::f03c:91ff:fe79:234c + 1800 RRSIG AAAA 8 2 1800 ( + 20170112031301 20161213031301 53289 atoom.net. + EkMxX2vUaP4h0qbWlHaT4yNhm8MrPMZTn/3R + zNw+i3oF2cLMWKh6GCfuIX/x5ID706o8kfum + bxTYwuTe1LJ+GoZHWEiH8VCa1laTlh8l3qSi + PZKU8339rr5cCYluk6p9PbAuRkYYOEruNg42 + wPOx46dsAlvp2XpOaOeJtU64QGQ= ) + 14400 NSEC deb.atoom.net. A NS SOA AAAA RRSIG NSEC DNSKEY + 14400 RRSIG NSEC 8 2 14400 ( + 20170112031301 20161213031301 53289 atoom.net. 
+ P7Stx7lqRKl8tbTAAaJ0W6UhgJwZz3cjpM8z + eplbhXEVohKtyJ9xgptKt1vreH6lkhzciar5 + EB9Nj0VOmcthiht/+As8aEKmf8UlcJ2EbLII + NT7NUaasxsrLE2rjjX5mEtzOZ1uQAGiU8Hnk + XdGweTgIVFuiCcMCgaKpC2TRrMw= ) + 1800 DNSKEY 256 3 8 ( + AwEAAeDZTH9YT9qLMPlq4VrxX7H3GbWcqCrC + tXc9RT/hf96GN+ttnnEQVaJY8Gbly3IZpYQW + MwaCi0t30UULXE3s9FUQtl4AMbplyiz9EF8L + /XoBS1yhGm5WV5u608ihoPaRkYNyVV3egb5Y + hA5EXWy2vfsa1XWPpxvSAhlqM0YENtP3 + ) ; ZSK; alg = RSASHA256; key id = 53289 + 1800 DNSKEY 257 3 8 ( + AwEAAepN7Vo8enDCruVduVlGxTDIv7QG0wJQ + fTL1hMy4k0Yf/7dXzrn5bZT4ytBvH1hoBImH + mtTrQo6DQlBBVXDJXTyQjQozaHpN1HhTJJTz + IXl8UrdbkLWvz6QSeJPmBBYQRAqylUA2KE29 + nxyiNboheDLiIWyQ7Q/Op7lYaKMdb555kQAs + b/XT4Tb3/3BhAjcofNofNBjDjPq2i8pAo8HU + 5mW5/Pl+ZT/S0aqQPnCkHk/iofSRu3ZdBzkH + 54eoC+BdyXb7gTbPGRr+1gMbf/rzhRiZ4vnX + NoEzGAXmorKzJHANNb6KQ/932V9UDHm9wbln + 6y3s7IBvsMX5KF8vo81Stkc= + ) ; KSK; alg = RSASHA256; key id = 19114 + 1800 RRSIG DNSKEY 8 2 1800 ( + 20170112031301 20161213031301 19114 atoom.net. + IEjViubKdef8RWB5bcnirqVcqDk16irkywJZ + sBjMyNs03/a+sl0UHEGAB7qCC+Rn+RDaM5It + WF+Gha6BwRIN9NuSg3BwB2h1nJtHw61pMVU9 + 2j9Q3pq7X1xoTBAcwY95t5a1xlw0iTCaLu1L + Iu/PbVp1gj1o8BF/PiYilvZJGUjaTgsi+YNi + 2kiWpp6afO78/W4nfVx+lQBmpyfX1lwL5PEC + 9f5PMbzRmOapvUBc2XdddGywLdmlNsLHimGV + t7kkHZHOWQR1TvvMbU3dsC0bFCrBVGDhEuxC + hATR+X5YV0AyDSyrew7fOGJKrapwMWS3yRLr + FAt0Vcxno5lwQImbCQ== ) + 1800 RRSIG DNSKEY 8 2 1800 ( + 20170112031301 20161213031301 53289 atoom.net. + sSxdgPT+gFZPN0ot6lZRGqOwvONUEsg0uEbf + kh19JlWHu/qvq5HOOK2VOW/UnswpVmtpFk0W + z/jiCNHifjpCCVn5tfCMZDLGekmPOjdobw24 + swBuGjnn0NHvxHoN6S+mb+AR6V/dLjquNUda + yzBc2Ua+XtQ7SCLKIvEhcNg9H3o= ) +deb.atoom.net. 1800 IN A 176.58.119.54 + 1800 RRSIG A 8 3 1800 ( + 20170112031301 20161213031301 53289 atoom.net. + ZW7jm/VDa/I9DxWlE7Cm+HHymiVv4Wk5UGYI + Uf/g0EfxLCBR6SwL5QKuV1z7xoWKaiNqqrmc + gg35xgskKyS8QHgCCODhDzcIKe+MSsBXbY04 + AtrC5dV3JJQoA65Ng/48hwcyghAjXKrA2Yyq + GXf2DSvWeIV9Jmk0CsOELP24dpk= ) + 1800 TXT "v=spf1 a ip6:2a01:7e00::f03c:91ff:fe79:234c ~all" + 1800 RRSIG TXT 8 3 1800 ( + 20170112031301 20161213031301 53289 atoom.net. + fpvVJ+Z6tzSd9yETn/PhLSCRISwRD1c3ET80 + 8twnx3XfAPQfV2R8dw7pz8Vw4TSxvf19bAZc + PWRjW682gb7gAxoJshCXBYabMfqExrBc9V1S + ezwm3D93xNMyegxzHx2b/H8qp3ZWdsMLTvvN + Azu7P4iyO+WRWT0R7bJGrdTwRz8= ) + 1800 AAAA 2a01:7e00::f03c:91ff:fe79:234c + 1800 RRSIG AAAA 8 3 1800 ( + 20170112031301 20161213031301 53289 atoom.net. + aaPF6NqXfWamzi+xUDVeYa7StJUVM1tDsL34 + w5uozFRZ0f4K/Z88Kk5CgztxmtpNNKGdLWa0 + iryUJsbVWAbSQfrZNkNckBtczMNxGgjqn97A + 2//F6ajH/qrR3dWcCm+VJMgu3UPqAxLiCaYO + GQUx6Y8JA1VIM/RJAM6BhgNxjD0= ) + 14400 NSEC lafhart.atoom.net. A TXT AAAA RRSIG NSEC + 14400 RRSIG NSEC 8 3 14400 ( + 20170112031301 20161213031301 53289 atoom.net. + 1Llad64NDWcz8CyBu2TsyANrJ9Tpfm5257sY + FPYF579p3c9Imwp9kYEO1zMEKgNoXBN/sQnd + YCugq3r2GAI6bfJj8sV5bt6GKuZcGHMESug4 + uh2gU0NDcCA4GPdBYGdusePwV0RNpcRnVCFA + fsACp+22j3uwRUbCh0re0ufbAs4= ) +lafhart.atoom.net. 1800 IN A 178.79.160.171 + 1800 RRSIG A 8 3 1800 ( + 20170112031301 20161213031301 53289 atoom.net. + fruP6cvMVICXEV8NcheS73NWLCEKlO1FgW6B + 35D2GhtfYZe+M23V5YBRtlVCCrAdS0etdCOf + xH9yt3u2kVvDXuMRiQr1zJPRDEq3cScYumpd + bOO8cjHiCic5lEcRVWNNHXyGtpqTvrp9CxOu + IQw1WgAlZyKj43zGg3WZi6OTKLg= ) + 14400 NSEC linode.atoom.net. A RRSIG NSEC + 14400 RRSIG NSEC 8 3 14400 ( + 20170112031301 20161213031301 53289 atoom.net. + 2AUWXbScL0jIJ7G6UsJAlUs+bgSprZ1zY6v/ + iVB5BAYwZD6pPky7LZdzvPEHh0aNLGIFbbU8 + SDJI7u/e4RUTlE+8yyjl6obZNfNKyJFqE5xN + 1BJ8sjFrVn6KaHIDKEOZunNb1MlMfCRkLg9O + 94zg04XEgVUfaYCPxvLs3fCEgzw= ) +voordeur.atoom.net. 
1800 IN A 77.249.87.46 + 1800 RRSIG A 8 3 1800 ( + 20170112031301 20161213031301 53289 atoom.net. + SzJz0NaKLRA/lW4CxgMHgeuQLp5QqFEjQv3I + zfPtY4joQsZn8RN8RLECcpcPKjbC8Dj6mxIJ + dd2vwhsCVlZKMNcZUOfpB7eGx1TR9HnzMkY9 + OdTt30a9+tktagrJEoy31vAhj1hJqLbSgvOa + pRr1P4ZpQ53/qH8JX/LOmqfWTdg= ) + 14400 NSEC www.atoom.net. A RRSIG NSEC + 14400 RRSIG NSEC 8 3 14400 ( + 20170112031301 20161213031301 53289 atoom.net. + CETJhUJy1rKjVj9wsW1549gth+/Z37//BI6S + nxJ+2Oq63jEjlbznmyo5hvFW54DbVUod+cLo + N9PdlNQDr1XsRBgWhkKW37RkuoRVEPwqRykv + xzn9i7CgYKAAHFyWMGihBLkV9ByPp8GDR8Zr + DEkrG3ErDlBcwi3FqGZFsSOW2xg= ) +www.atoom.net. 1800 IN CNAME deb.atoom.net. + 1800 RRSIG CNAME 8 3 1800 ( + 20170112031301 20161213031301 53289 atoom.net. + 1lhG6iTtbeesBCVOrA8a7+V2gogCuXzKgSi8 + 6K0Pzq2CwqTScdNcZvcDOIbLq45Am5p09PIj + lXnd2fw6WAxphwvRhmwCve3uTZMUt5STw7oi + 0rED7GMuFUSC/BX0XVly7NET3ECa1vaK6RhO + hDSsKPWFI7to4d1z6tQ9j9Kvm4Y= ) + 14400 NSEC atoom.net. CNAME RRSIG NSEC + 14400 RRSIG NSEC 8 3 14400 ( + 20170112031301 20161213031301 53289 atoom.net. + CC4yCYP1q75/gTmPz+mVM6Lam2foPP5oTccY + RtROuTkgbt8DtAoPe304vmNazWBlGidnWJeD + YyAAe3znIHP0CgrxjD/hRL9FUzMnVrvB3mnx + 4W13wP1rE97RqJxV1kk22Wl3uCkVGy7LCjb0 + JLFvzCe2fuMe7YcTzI+t1rioTP0= ) +linode.atoom.net. 1800 IN A 176.58.119.54 + 1800 RRSIG A 8 3 1800 ( + 20170112031301 20161213031301 53289 atoom.net. + Z4Ka4OLDha4eQNWs3GtUd1Cumr48RUnH523I + nZzGXtpQNou70qsm5Jt8n/HmsZ4L5DoxomRz + rgZTGnrqj43+A16UUGfVEk6SfUUHOgxgspQW + zoaqk5/5mQO1ROsLKY8RqaRqzvbToHvqeZEh + VkTPVA02JK9UFlKqoyxj72CLvkI= ) + 1800 AAAA 2a01:7e00::f03c:91ff:fe79:234c + 1800 RRSIG AAAA 8 3 1800 ( + 20170112031301 20161213031301 53289 atoom.net. + l+9Qce/EQyKrTJVKLv7iatjuCO285ckd5Oie + P2LzWVsL4tW04oHzieKZwIuNBRE+px8g5qrT + LIK2TikCGL1xHAd7CT7gbCtDcZ7jHmSTmMTJ + 405nOV3G3xWelreLI5Fn5ck8noEsF64kiw1y + XfkyQn2B914zFH/okG2fzJ1qolQ= ) + 14400 NSEC voordeur.atoom.net. A AAAA RRSIG NSEC + 14400 RRSIG NSEC 8 3 14400 ( + 20170112031301 20161213031301 53289 atoom.net. + Owzmz7QrVL2Gw2njEsUVEknMl2amx1HG9X3K + tO+Ihyy4tApiUFxUjAu3P/30QdqbB85h7s// + ipwX/AmQJNoxTScR3nHt9qDqJ044DPmiuh0l + NuIjguyZRANApmKCTA6AoxXIUqToIIjfVzi/ + PxXE6T3YIPlK7Bxgv1lcCBJ1fmE= )` + +const atoom = "atoom.net." diff --git a/plugin/file/include_test.go b/plugin/file/include_test.go new file mode 100644 index 0000000..490f05a --- /dev/null +++ b/plugin/file/include_test.go @@ -0,0 +1,31 @@ +package file + +import ( + "strings" + "testing" + + "github.com/coredns/coredns/plugin/test" +) + +// Make sure the external miekg/dns dependency is up to date + +func TestInclude(t *testing.T) { + name, rm, err := test.TempFile(".", "foo\tIN\tA\t127.0.0.1\n") + if err != nil { + t.Fatalf("Unable to create tmpfile %q: %s", name, err) + } + defer rm() + + zone := `$ORIGIN example.org. +@ IN SOA sns.dns.icann.org. noc.dns.icann.org. 
2017042766 7200 3600 1209600 3600 +$INCLUDE ` + name + "\n" + + z, err := Parse(strings.NewReader(zone), "example.org.", "test", 0) + if err != nil { + t.Errorf("Unable to parse zone %q: %s", "example.org.", err) + } + + if _, ok := z.Search("foo.example.org."); !ok { + t.Errorf("Failed to find %q in parsed zone", "foo.example.org.") + } +} diff --git a/plugin/file/log_test.go b/plugin/file/log_test.go new file mode 100644 index 0000000..c9609ee --- /dev/null +++ b/plugin/file/log_test.go @@ -0,0 +1,5 @@ +package file + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/file/lookup.go b/plugin/file/lookup.go new file mode 100644 index 0000000..3f69299 --- /dev/null +++ b/plugin/file/lookup.go @@ -0,0 +1,435 @@ +package file + +import ( + "context" + + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin/file/rrutil" + "github.com/coredns/coredns/plugin/file/tree" + "github.com/coredns/coredns/plugin/metadata" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// Result is the result of a Lookup +type Result int + +const ( + // Success is a successful lookup. + Success Result = iota + // NameError indicates a nameerror + NameError + // Delegation indicates the lookup resulted in a delegation. + Delegation + // NoData indicates the lookup resulted in a NODATA. + NoData + // ServerFailure indicates a server failure during the lookup. + ServerFailure +) + +// Lookup looks up qname and qtype in the zone. When do is true DNSSEC records are included. +// Three sets of records are returned, one for the answer, one for authority and one for the additional section. +func (z *Zone) Lookup(ctx context.Context, state request.Request, qname string) ([]dns.RR, []dns.RR, []dns.RR, Result) { + qtype := state.QType() + do := state.Do() + + // If z is a secondary zone we might not have transferred it, meaning we have + // all zone context setup, except the actual record. This means (for one thing) the apex + // is empty and we don't have a SOA record. + z.RLock() + ap := z.Apex + tr := z.Tree + z.RUnlock() + if ap.SOA == nil { + return nil, nil, nil, ServerFailure + } + + if qname == z.origin { + switch qtype { + case dns.TypeSOA: + return ap.soa(do), ap.ns(do), nil, Success + case dns.TypeNS: + nsrrs := ap.ns(do) + glue := tr.Glue(nsrrs, do) // technically this isn't glue + return nsrrs, nil, glue, Success + } + } + + var ( + found, shot bool + parts string + i int + elem, wildElem *tree.Elem + ) + + loop, _ := ctx.Value(dnsserver.LoopKey{}).(int) + if loop > 8 { + // We're back here for the 9th time; we have a loop and need to bail out. + // Note the answer we're returning will be incomplete (more cnames to be followed) or + // illegal (wildcard cname with multiple identical records). For now it's more important + // to protect ourselves then to give the client a valid answer. We return with an error + // to let the server handle what to do. + return nil, nil, nil, ServerFailure + } + + // Lookup: + // * Per label from the right, look if it exists. We do this to find potential + // delegation records. + // * If the per-label search finds nothing, we will look for the wildcard at the + // level. If found we keep it around. If we don't find the complete name we will + // use the wildcard. + // + // Main for-loop handles delegation and finding or not finding the qname. + // If found we check if it is a CNAME/DNAME and do CNAME processing + // We also check if we have type and do a nodata response. 
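+	//
+	// A rough illustration of the per-label walk for qname a.b.c.miek.nl. in
+	// zone miek.nl. (compare the empty-non-terminal fixture in ent_test.go):
+	//
+	//	i=0: miek.nl.        apex, always in the tree
+	//	i=1: c.miek.nl.      empty non-terminal, not stored, keep searching
+	//	i=2: b.c.miek.nl.    empty non-terminal, keep searching
+	//	i=3: a.b.c.miek.nl.  full qname, found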
+ // + // If not found, we check the potential wildcard, and use that for further processing. + // If not found and no wildcard we will process this as an NXDOMAIN response. + for { + parts, shot = z.nameFromRight(qname, i) + // We overshot the name, break and check if we previously found something. + if shot { + break + } + + elem, found = tr.Search(parts) + if !found { + // Apex will always be found, when we are here we can search for a wildcard + // and save the result of that search. So when nothing match, but we have a + // wildcard we should expand the wildcard. + + wildcard := replaceWithAsteriskLabel(parts) + if wild, found := tr.Search(wildcard); found { + wildElem = wild + } + + // Keep on searching, because maybe we hit an empty-non-terminal (which aren't + // stored in the tree. Only when we have match the full qname (and possible wildcard + // we can be confident that we didn't find anything. + i++ + continue + } + + // If we see DNAME records, we should return those. + if dnamerrs := elem.Type(dns.TypeDNAME); dnamerrs != nil { + // Only one DNAME is allowed per name. We just pick the first one to synthesize from. + dname := dnamerrs[0] + if cname := synthesizeCNAME(state.Name(), dname.(*dns.DNAME)); cname != nil { + var ( + answer, ns, extra []dns.RR + rcode Result + ) + + // We don't need to chase CNAME chain for synthesized CNAME + if qtype == dns.TypeCNAME { + answer = []dns.RR{cname} + ns = ap.ns(do) + extra = nil + rcode = Success + } else { + ctx = context.WithValue(ctx, dnsserver.LoopKey{}, loop+1) + answer, ns, extra, rcode = z.externalLookup(ctx, state, elem, []dns.RR{cname}) + } + + if do { + sigs := elem.Type(dns.TypeRRSIG) + sigs = rrutil.SubTypeSignature(sigs, dns.TypeDNAME) + dnamerrs = append(dnamerrs, sigs...) + } + + // The relevant DNAME RR should be included in the answer section, + // if the DNAME is being employed as a substitution instruction. + answer = append(dnamerrs, answer...) + + return answer, ns, extra, rcode + } + // The domain name that owns a DNAME record is allowed to have other RR types + // at that domain name, except those have restrictions on what they can coexist + // with (e.g. another DNAME). So there is nothing special left here. + } + + // If we see NS records, it means the name as been delegated, and we should return the delegation. + if nsrrs := elem.Type(dns.TypeNS); nsrrs != nil { + // If the query is specifically for DS and the qname matches the delegated name, we should + // return the DS in the answer section and leave the rest empty, i.e. just continue the loop + // and continue searching. + if qtype == dns.TypeDS && elem.Name() == qname { + i++ + continue + } + + glue := tr.Glue(nsrrs, do) + if do { + dss := typeFromElem(elem, dns.TypeDS, do) + nsrrs = append(nsrrs, dss...) + } + + return nil, nsrrs, glue, Delegation + } + + i++ + } + + // What does found and !shot mean - do we ever hit it? + if found && !shot { + return nil, nil, nil, ServerFailure + } + + // Found entire name. + if found && shot { + if rrs := elem.Type(dns.TypeCNAME); len(rrs) > 0 && qtype != dns.TypeCNAME { + ctx = context.WithValue(ctx, dnsserver.LoopKey{}, loop+1) + return z.externalLookup(ctx, state, elem, rrs) + } + + rrs := elem.Type(qtype) + + // NODATA + if len(rrs) == 0 { + ret := ap.soa(do) + if do { + nsec := typeFromElem(elem, dns.TypeNSEC, do) + ret = append(ret, nsec...) + } + return nil, ret, nil, NoData + } + + // Additional section processing for MX, SRV. 
Check response and see if any of the names are in bailiwick - + // if so add IP addresses to the additional section. + additional := z.additionalProcessing(rrs, do) + + if do { + sigs := elem.Type(dns.TypeRRSIG) + sigs = rrutil.SubTypeSignature(sigs, qtype) + rrs = append(rrs, sigs...) + } + + return rrs, ap.ns(do), additional, Success + } + + // Haven't found the original name. + + // Found wildcard. + if wildElem != nil { + // set metadata value for the wildcard record that synthesized the result + metadata.SetValueFunc(ctx, "zone/wildcard", func() string { + return wildElem.Name() + }) + + if rrs := wildElem.TypeForWildcard(dns.TypeCNAME, qname); len(rrs) > 0 && qtype != dns.TypeCNAME { + ctx = context.WithValue(ctx, dnsserver.LoopKey{}, loop+1) + return z.externalLookup(ctx, state, wildElem, rrs) + } + + rrs := wildElem.TypeForWildcard(qtype, qname) + + // NODATA response. + if len(rrs) == 0 { + ret := ap.soa(do) + if do { + nsec := typeFromElem(wildElem, dns.TypeNSEC, do) + ret = append(ret, nsec...) + } + return nil, ret, nil, NoData + } + + auth := ap.ns(do) + if do { + // An NSEC is needed to say no longer name exists under this wildcard. + if deny, found := tr.Prev(qname); found { + nsec := typeFromElem(deny, dns.TypeNSEC, do) + auth = append(auth, nsec...) + } + + sigs := wildElem.TypeForWildcard(dns.TypeRRSIG, qname) + sigs = rrutil.SubTypeSignature(sigs, qtype) + rrs = append(rrs, sigs...) + } + return rrs, auth, nil, Success + } + + rcode := NameError + + // Hacky way to get around empty-non-terminals. If a longer name does exist, but this qname, does not, it + // must be an empty-non-terminal. If so, we do the proper NXDOMAIN handling, but set the rcode to be success. + if x, found := tr.Next(qname); found { + if dns.IsSubDomain(qname, x.Name()) { + rcode = Success + } + } + + ret := ap.soa(do) + if do { + deny, found := tr.Prev(qname) + if !found { + goto Out + } + nsec := typeFromElem(deny, dns.TypeNSEC, do) + ret = append(ret, nsec...) + + if rcode != NameError { + goto Out + } + + ce, found := z.ClosestEncloser(qname) + + // wildcard denial only for NXDOMAIN + if found { + // wildcard denial + wildcard := "*." + ce.Name() + if ss, found := tr.Prev(wildcard); found { + // Only add this nsec if it is different than the one already added + if ss.Name() != deny.Name() { + nsec := typeFromElem(ss, dns.TypeNSEC, do) + ret = append(ret, nsec...) + } + } + } + } +Out: + return nil, ret, nil, rcode +} + +// typeFromElem returns the type tp from e and adds signatures (if they exist) and do is true. +func typeFromElem(elem *tree.Elem, tp uint16, do bool) []dns.RR { + rrs := elem.Type(tp) + if do { + sigs := elem.Type(dns.TypeRRSIG) + sigs = rrutil.SubTypeSignature(sigs, tp) + rrs = append(rrs, sigs...) + } + return rrs +} + +func (a Apex) soa(do bool) []dns.RR { + if do { + ret := append([]dns.RR{a.SOA}, a.SIGSOA...) + return ret + } + return []dns.RR{a.SOA} +} + +func (a Apex) ns(do bool) []dns.RR { + if do { + ret := append(a.NS, a.SIGNS...) + return ret + } + return a.NS +} + +// externalLookup adds signatures and tries to resolve CNAMEs that point to external names. +func (z *Zone) externalLookup(ctx context.Context, state request.Request, elem *tree.Elem, rrs []dns.RR) ([]dns.RR, []dns.RR, []dns.RR, Result) { + qtype := state.QType() + do := state.Do() + + if do { + sigs := elem.Type(dns.TypeRRSIG) + sigs = rrutil.SubTypeSignature(sigs, dns.TypeCNAME) + rrs = append(rrs, sigs...) 
+ } + + targetName := rrs[0].(*dns.CNAME).Target + elem, _ = z.Tree.Search(targetName) + if elem == nil { + lookupRRs, result := z.doLookup(ctx, state, targetName, qtype) + rrs = append(rrs, lookupRRs...) + return rrs, z.Apex.ns(do), nil, result + } + + i := 0 + +Redo: + cname := elem.Type(dns.TypeCNAME) + if len(cname) > 0 { + rrs = append(rrs, cname...) + + if do { + sigs := elem.Type(dns.TypeRRSIG) + sigs = rrutil.SubTypeSignature(sigs, dns.TypeCNAME) + rrs = append(rrs, sigs...) + } + targetName := cname[0].(*dns.CNAME).Target + elem, _ = z.Tree.Search(targetName) + if elem == nil { + lookupRRs, result := z.doLookup(ctx, state, targetName, qtype) + rrs = append(rrs, lookupRRs...) + return rrs, z.Apex.ns(do), nil, result + } + + i++ + if i > 8 { + return rrs, z.Apex.ns(do), nil, Success + } + + goto Redo + } + + targets := elem.Type(qtype) + if len(targets) > 0 { + rrs = append(rrs, targets...) + + if do { + sigs := elem.Type(dns.TypeRRSIG) + sigs = rrutil.SubTypeSignature(sigs, qtype) + rrs = append(rrs, sigs...) + } + } + + return rrs, z.Apex.ns(do), nil, Success +} + +func (z *Zone) doLookup(ctx context.Context, state request.Request, target string, qtype uint16) ([]dns.RR, Result) { + m, e := z.Upstream.Lookup(ctx, state, target, qtype) + if e != nil { + return nil, ServerFailure + } + if m == nil { + return nil, Success + } + if m.Rcode == dns.RcodeNameError { + return m.Answer, NameError + } + if m.Rcode == dns.RcodeServerFailure { + return m.Answer, ServerFailure + } + if m.Rcode == dns.RcodeSuccess && len(m.Answer) == 0 { + return m.Answer, NoData + } + return m.Answer, Success +} + +// additionalProcessing checks the current answer section and retrieves A or AAAA records +// (and possible SIGs) to need to be put in the additional section. +func (z *Zone) additionalProcessing(answer []dns.RR, do bool) (extra []dns.RR) { + for _, rr := range answer { + name := "" + switch x := rr.(type) { + case *dns.SRV: + name = x.Target + case *dns.MX: + name = x.Mx + } + if len(name) == 0 || !dns.IsSubDomain(z.origin, name) { + continue + } + + elem, _ := z.Tree.Search(name) + if elem == nil { + continue + } + + sigs := elem.Type(dns.TypeRRSIG) + for _, addr := range []uint16{dns.TypeA, dns.TypeAAAA} { + if a := elem.Type(addr); a != nil { + extra = append(extra, a...) + if do { + sig := rrutil.SubTypeSignature(sigs, addr) + extra = append(extra, sig...) + } + } + } + } + + return extra +} diff --git a/plugin/file/lookup_test.go b/plugin/file/lookup_test.go new file mode 100644 index 0000000..79e5604 --- /dev/null +++ b/plugin/file/lookup_test.go @@ -0,0 +1,284 @@ +package file + +import ( + "context" + "strings" + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +var dnsTestCases = []test.Case{ + { + Qname: "www.miek.nl.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.A("a.miek.nl. 1800 IN A 139.162.196.78"), + test.CNAME("www.miek.nl. 1800 IN CNAME a.miek.nl."), + }, + Ns: miekAuth, + }, + { + Qname: "www.miek.nl.", Qtype: dns.TypeAAAA, + Answer: []dns.RR{ + test.AAAA("a.miek.nl. 1800 IN AAAA 2a01:7e00::f03c:91ff:fef1:6735"), + test.CNAME("www.miek.nl. 1800 IN CNAME a.miek.nl."), + }, + Ns: miekAuth, + }, + { + Qname: "miek.nl.", Qtype: dns.TypeSOA, + Answer: []dns.RR{ + test.SOA("miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 
1282630057 14400 3600 604800 14400"), + }, + Ns: miekAuth, + }, + { + Qname: "miek.nl.", Qtype: dns.TypeAAAA, + Answer: []dns.RR{ + test.AAAA("miek.nl. 1800 IN AAAA 2a01:7e00::f03c:91ff:fef1:6735"), + }, + Ns: miekAuth, + }, + { + Qname: "mIeK.NL.", Qtype: dns.TypeAAAA, + Answer: []dns.RR{ + test.AAAA("miek.nl. 1800 IN AAAA 2a01:7e00::f03c:91ff:fef1:6735"), + }, + Ns: miekAuth, + }, + { + Qname: "miek.nl.", Qtype: dns.TypeMX, + Answer: []dns.RR{ + test.MX("miek.nl. 1800 IN MX 1 aspmx.l.google.com."), + test.MX("miek.nl. 1800 IN MX 10 aspmx2.googlemail.com."), + test.MX("miek.nl. 1800 IN MX 10 aspmx3.googlemail.com."), + test.MX("miek.nl. 1800 IN MX 5 alt1.aspmx.l.google.com."), + test.MX("miek.nl. 1800 IN MX 5 alt2.aspmx.l.google.com."), + }, + Ns: miekAuth, + }, + { + Qname: "a.miek.nl.", Qtype: dns.TypeSRV, + Ns: []dns.RR{ + test.SOA("miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 1282630057 14400 3600 604800 14400"), + }, + }, + { + Qname: "b.miek.nl.", Qtype: dns.TypeA, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 1282630057 14400 3600 604800 14400"), + }, + }, + { + Qname: "srv.miek.nl.", Qtype: dns.TypeSRV, + Answer: []dns.RR{ + test.SRV("srv.miek.nl. 1800 IN SRV 10 10 8080 a.miek.nl."), + }, + Extra: []dns.RR{ + test.A("a.miek.nl. 1800 IN A 139.162.196.78"), + test.AAAA("a.miek.nl. 1800 IN AAAA 2a01:7e00::f03c:91ff:fef1:6735"), + }, + Ns: miekAuth, + }, + { + Qname: "mx.miek.nl.", Qtype: dns.TypeMX, + Answer: []dns.RR{ + test.MX("mx.miek.nl. 1800 IN MX 10 a.miek.nl."), + }, + Extra: []dns.RR{ + test.A("a.miek.nl. 1800 IN A 139.162.196.78"), + test.AAAA("a.miek.nl. 1800 IN AAAA 2a01:7e00::f03c:91ff:fef1:6735"), + }, + Ns: miekAuth, + }, + { + Qname: "asterisk.x.miek.nl.", Qtype: dns.TypeCNAME, + Answer: []dns.RR{ + test.CNAME("asterisk.x.miek.nl. 1800 IN CNAME www.miek.nl."), + }, + Ns: miekAuth, + }, + { + Qname: "a.b.x.miek.nl.", Qtype: dns.TypeCNAME, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 1282630057 14400 3600 604800 14400"), + }, + }, + { + Qname: "asterisk.y.miek.nl.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.A("asterisk.y.miek.nl. 1800 IN A 139.162.196.78"), + }, + Ns: miekAuth, + }, + { + Qname: "foo.dname.miek.nl.", Qtype: dns.TypeCNAME, + Answer: []dns.RR{ + test.DNAME("dname.miek.nl. 1800 IN DNAME x.miek.nl."), + test.CNAME("foo.dname.miek.nl. 1800 IN CNAME foo.x.miek.nl."), + }, + Ns: miekAuth, + }, + { + Qname: "ext-cname.miek.nl.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.CNAME("ext-cname.miek.nl. 1800 IN CNAME example.com."), + }, + Rcode: dns.RcodeServerFailure, + Ns: miekAuth, + }, + { + Qname: "txt.miek.nl.", Qtype: dns.TypeTXT, + Answer: []dns.RR{ + test.TXT(`txt.miek.nl. 1800 IN TXT "v=spf1 a mx ~all"`), + }, + Ns: miekAuth, + }, + { + Qname: "caa.miek.nl.", Qtype: dns.TypeCAA, + Answer: []dns.RR{ + test.CAA(`caa.miek.nl. 1800 IN CAA 0 issue letsencrypt.org`), + }, + Ns: miekAuth, + }, +} + +const ( + testzone = "miek.nl." + testzone1 = "dnssex.nl." 
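+	// testzone1 names the signed dnssex.nl. fixture from dnssex_test.go; it is
+	// presumably referenced by the DNSSEC test cases elsewhere in this package.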
+) + +func TestLookup(t *testing.T) { + zone, err := Parse(strings.NewReader(dbMiekNL), testzone, "stdin", 0) + if err != nil { + t.Fatalf("Expected no error when reading zone, got %q", err) + } + + fm := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{testzone: zone}, Names: []string{testzone}}} + ctx := context.TODO() + + for _, tc := range dnsTestCases { + m := tc.Msg() + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err := fm.ServeDNS(ctx, rec, m) + if err != nil { + t.Errorf("Expected no error, got %v", err) + return + } + + resp := rec.Msg + if err := test.SortAndCheck(resp, tc); err != nil { + t.Error(err) + } + } +} + +func TestLookupNil(t *testing.T) { + fm := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{testzone: nil}, Names: []string{testzone}}} + ctx := context.TODO() + + m := dnsTestCases[0].Msg() + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + fm.ServeDNS(ctx, rec, m) +} + +func TestLookUpNoDataResult(t *testing.T) { + zone, err := Parse(strings.NewReader(dbMiekNL), testzone, "stdin", 0) + if err != nil { + t.Fatalf("Expected no error when reading zone, got %q", err) + } + + fm := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{testzone: zone}, Names: []string{testzone}}} + ctx := context.TODO() + var noDataTestCases = []test.Case{ + { + Qname: "a.miek.nl.", Qtype: dns.TypeMX, + }, + { + Qname: "wildcard.nodata.miek.nl.", Qtype: dns.TypeMX, + }, + } + + for _, tc := range noDataTestCases { + m := tc.Msg() + state := request.Request{W: &test.ResponseWriter{}, Req: m} + _, _, _, result := fm.Z[testzone].Lookup(ctx, state, tc.Qname) + if result != NoData { + t.Errorf("Expected result == 3 but result == %v ", result) + } + } +} + +func BenchmarkFileLookup(b *testing.B) { + zone, err := Parse(strings.NewReader(dbMiekNL), testzone, "stdin", 0) + if err != nil { + return + } + + fm := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{testzone: zone}, Names: []string{testzone}}} + ctx := context.TODO() + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + + tc := test.Case{ + Qname: "www.miek.nl.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.CNAME("www.miek.nl. 1800 IN CNAME a.miek.nl."), + test.A("a.miek.nl. 1800 IN A 139.162.196.78"), + }, + } + + m := tc.Msg() + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + fm.ServeDNS(ctx, rec, m) + } +} + +const dbMiekNL = ` +$TTL 30M +$ORIGIN miek.nl. +@ IN SOA linode.atoom.net. miek.miek.nl. ( + 1282630057 ; Serial + 4H ; Refresh + 1H ; Retry + 7D ; Expire + 4H ) ; Negative Cache TTL + IN NS linode.atoom.net. + IN NS ns-ext.nlnetlabs.nl. + IN NS omval.tednet.nl. + IN NS ext.ns.whyscream.net. + + IN MX 1 aspmx.l.google.com. + IN MX 5 alt1.aspmx.l.google.com. + IN MX 5 alt2.aspmx.l.google.com. + IN MX 10 aspmx2.googlemail.com. + IN MX 10 aspmx3.googlemail.com. + + IN A 139.162.196.78 + IN AAAA 2a01:7e00::f03c:91ff:fef1:6735 + +a IN A 139.162.196.78 + IN AAAA 2a01:7e00::f03c:91ff:fef1:6735 +www IN CNAME a +archive IN CNAME a +*.x IN CNAME www +b.x IN CNAME a +*.y IN A 139.162.196.78 +dname IN DNAME x + +srv IN SRV 10 10 8080 a.miek.nl. +mx IN MX 10 a.miek.nl. 
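+; fixtures for the TXT, CAA, wildcard-NODATA and external-CNAME test cases above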
+ +txt IN TXT "v=spf1 a mx ~all" +caa IN CAA 0 issue letsencrypt.org +*.nodata IN A 139.162.196.79 +ext-cname IN CNAME example.com.` diff --git a/plugin/file/notify.go b/plugin/file/notify.go new file mode 100644 index 0000000..7d4e35c --- /dev/null +++ b/plugin/file/notify.go @@ -0,0 +1,33 @@ +package file + +import ( + "net" + + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// isNotify checks if state is a notify message and if so, will *also* check if it +// is from one of the configured masters. If not it will not be a valid notify +// message. If the zone z is not a secondary zone the message will also be ignored. +func (z *Zone) isNotify(state request.Request) bool { + if state.Req.Opcode != dns.OpcodeNotify { + return false + } + if len(z.TransferFrom) == 0 { + return false + } + // If remote IP matches we accept. + remote := state.IP() + for _, f := range z.TransferFrom { + from, _, err := net.SplitHostPort(f) + if err != nil { + continue + } + if from == remote { + return true + } + } + return false +} diff --git a/plugin/file/nsec3_test.go b/plugin/file/nsec3_test.go new file mode 100644 index 0000000..ed9f74f --- /dev/null +++ b/plugin/file/nsec3_test.go @@ -0,0 +1,28 @@ +package file + +import ( + "strings" + "testing" +) + +func TestParseNSEC3PARAM(t *testing.T) { + _, err := Parse(strings.NewReader(nsec3paramTest), "miek.nl", "stdin", 0) + if err == nil { + t.Fatalf("Expected error when reading zone, got nothing") + } +} + +func TestParseNSEC3(t *testing.T) { + _, err := Parse(strings.NewReader(nsec3Test), "miek.nl", "stdin", 0) + if err == nil { + t.Fatalf("Expected error when reading zone, got nothing") + } +} + +const nsec3paramTest = `miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 1460175181 14400 3600 604800 14400 +miek.nl. 1800 IN NS omval.tednet.nl. +miek.nl. 0 IN NSEC3PARAM 1 0 5 A3DEBC9CC4F695C7` + +const nsec3Test = `example.org. 1800 IN SOA sns.dns.icann.org. noc.dns.icann.org. 2016082508 7200 3600 1209600 3600 +aub8v9ce95ie18spjubsr058h41n7pa5.example.org. 284 IN NSEC3 1 1 5 D0CBEAAF0AC77314 AUB95P93VPKP55G6U5S4SGS7LS61ND85 NS SOA TXT RRSIG DNSKEY NSEC3PARAM +aub8v9ce95ie18spjubsr058h41n7pa5.example.org. 284 IN RRSIG NSEC3 8 2 600 20160910232502 20160827231002 14028 example.org. XBNpA7KAIjorPbXvTinOHrc1f630aHic2U716GHLHA4QMx9cl9ss4QjR Wj2UpDM9zBW/jNYb1xb0yjQoez/Jv200w0taSWjRci5aUnRpOi9bmcrz STHb6wIUjUsbJ+NstQsUwVkj6679UviF1FqNwr4GlJnWG3ZrhYhE+NI6 s0k=` diff --git a/plugin/file/reload.go b/plugin/file/reload.go new file mode 100644 index 0000000..cdb50f4 --- /dev/null +++ b/plugin/file/reload.go @@ -0,0 +1,69 @@ +package file + +import ( + "os" + "path/filepath" + "time" + + "github.com/coredns/coredns/plugin/transfer" +) + +// Reload reloads a zone when it is changed on disk. If z.ReloadInterval is zero, no reloading will be done. 
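+// Reloading is poll based: a ticker fires every ReloadInterval, the file is
+// re-read, and the in-memory Apex and Tree are swapped only when the SOA serial
+// changed (an unchanged serial makes Parse return a *serialErr, which is skipped).
+// A minimal usage sketch, assuming the zone was created and parsed elsewhere:
+//
+//	z.ReloadInterval = time.Minute
+//	z.Reload(nil) // a nil *transfer.Transfer just skips sending NOTIFYs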
+func (z *Zone) Reload(t *transfer.Transfer) error { + if z.ReloadInterval == 0 { + return nil + } + tick := time.NewTicker(z.ReloadInterval) + + go func() { + for { + select { + case <-tick.C: + zFile := z.File() + reader, err := os.Open(filepath.Clean(zFile)) + if err != nil { + log.Errorf("Failed to open zone %q in %q: %v", z.origin, zFile, err) + continue + } + + serial := z.SOASerialIfDefined() + zone, err := Parse(reader, z.origin, zFile, serial) + reader.Close() + if err != nil { + if _, ok := err.(*serialErr); !ok { + log.Errorf("Parsing zone %q: %v", z.origin, err) + } + continue + } + + // copy elements we need + z.Lock() + z.Apex = zone.Apex + z.Tree = zone.Tree + z.Unlock() + + log.Infof("Successfully reloaded zone %q in %q with %d SOA serial", z.origin, zFile, z.Apex.SOA.Serial) + if t != nil { + if err := t.Notify(z.origin); err != nil { + log.Warningf("Failed sending notifies: %s", err) + } + } + + case <-z.reloadShutdown: + tick.Stop() + return + } + } + }() + return nil +} + +// SOASerialIfDefined returns the SOA's serial if the zone has a SOA record in the Apex, or -1 otherwise. +func (z *Zone) SOASerialIfDefined() int64 { + z.RLock() + defer z.RUnlock() + if z.Apex.SOA != nil { + return int64(z.Apex.SOA.Serial) + } + return -1 +} diff --git a/plugin/file/reload_test.go b/plugin/file/reload_test.go new file mode 100644 index 0000000..c404bc4 --- /dev/null +++ b/plugin/file/reload_test.go @@ -0,0 +1,90 @@ +package file + +import ( + "context" + "os" + "strings" + "testing" + "time" + + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/plugin/transfer" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +func TestZoneReload(t *testing.T) { + fileName, rm, err := test.TempFile(".", reloadZoneTest) + if err != nil { + t.Fatalf("Failed to create zone: %s", err) + } + defer rm() + reader, err := os.Open(fileName) + if err != nil { + t.Fatalf("Failed to open zone: %s", err) + } + z, err := Parse(reader, "miek.nl", fileName, 0) + if err != nil { + t.Fatalf("Failed to parse zone: %s", err) + } + + z.ReloadInterval = 10 * time.Millisecond + z.Reload(&transfer.Transfer{}) + time.Sleep(20 * time.Millisecond) + + ctx := context.TODO() + r := new(dns.Msg) + r.SetQuestion("miek.nl", dns.TypeSOA) + state := request.Request{W: &test.ResponseWriter{}, Req: r} + if _, _, _, res := z.Lookup(ctx, state, "miek.nl."); res != Success { + t.Fatalf("Failed to lookup, got %d", res) + } + + r = new(dns.Msg) + r.SetQuestion("miek.nl", dns.TypeNS) + state = request.Request{W: &test.ResponseWriter{}, Req: r} + if _, _, _, res := z.Lookup(ctx, state, "miek.nl."); res != Success { + t.Fatalf("Failed to lookup, got %d", res) + } + + rrs, err := z.ApexIfDefined() // all apex records. + if err != nil { + t.Fatal(err) + } + if len(rrs) != 5 { + t.Fatalf("Expected 5 RRs, got %d", len(rrs)) + } + if err := os.WriteFile(fileName, []byte(reloadZone2Test), 0644); err != nil { + t.Fatalf("Failed to write new zone data: %s", err) + } + // Could still be racy, but we need to wait a bit for the event to be seen + time.Sleep(30 * time.Millisecond) + + rrs, err = z.ApexIfDefined() + if err != nil { + t.Fatal(err) + } + if len(rrs) != 3 { + t.Fatalf("Expected 3 RRs, got %d", len(rrs)) + } +} + +func TestZoneReloadSOAChange(t *testing.T) { + _, err := Parse(strings.NewReader(reloadZoneTest), "miek.nl.", "stdin", 1460175181) + if err == nil { + t.Fatalf("Zone should not have been re-parsed") + } +} + +const reloadZoneTest = `miek.nl. 1627 IN SOA linode.atoom.net. miek.miek.nl. 
1460175181 14400 3600 604800 14400 +miek.nl. 1627 IN NS ext.ns.whyscream.net. +miek.nl. 1627 IN NS omval.tednet.nl. +miek.nl. 1627 IN NS linode.atoom.net. +miek.nl. 1627 IN NS ns-ext.nlnetlabs.nl. +` + +const reloadZone2Test = `miek.nl. 1627 IN SOA linode.atoom.net. miek.miek.nl. 1460175182 14400 3600 604800 14400 +miek.nl. 1627 IN NS ext.ns.whyscream.net. +miek.nl. 1627 IN NS omval.tednet.nl. +` diff --git a/plugin/file/rrutil/util.go b/plugin/file/rrutil/util.go new file mode 100644 index 0000000..564b82c --- /dev/null +++ b/plugin/file/rrutil/util.go @@ -0,0 +1,18 @@ +// Package rrutil provides function to find certain RRs in slices. +package rrutil + +import "github.com/miekg/dns" + +// SubTypeSignature returns the RRSIG for the subtype. +func SubTypeSignature(rrs []dns.RR, subtype uint16) []dns.RR { + sigs := []dns.RR{} + // there may be multiple keys that have signed this subtype + for _, sig := range rrs { + if s, ok := sig.(*dns.RRSIG); ok { + if s.TypeCovered == subtype { + sigs = append(sigs, s) + } + } + } + return sigs +} diff --git a/plugin/file/secondary.go b/plugin/file/secondary.go new file mode 100644 index 0000000..932916b --- /dev/null +++ b/plugin/file/secondary.go @@ -0,0 +1,198 @@ +package file + +import ( + "math/rand" + "time" + + "github.com/miekg/dns" +) + +// TransferIn retrieves the zone from the masters, parses it and sets it live. +func (z *Zone) TransferIn() error { + if len(z.TransferFrom) == 0 { + return nil + } + m := new(dns.Msg) + m.SetAxfr(z.origin) + + z1 := z.CopyWithoutApex() + var ( + Err error + tr string + ) + +Transfer: + for _, tr = range z.TransferFrom { + t := new(dns.Transfer) + c, err := t.In(m, tr) + if err != nil { + log.Errorf("Failed to setup transfer `%s' with `%q': %v", z.origin, tr, err) + Err = err + continue Transfer + } + for env := range c { + if env.Error != nil { + log.Errorf("Failed to transfer `%s' from %q: %v", z.origin, tr, env.Error) + Err = env.Error + continue Transfer + } + for _, rr := range env.RR { + if err := z1.Insert(rr); err != nil { + log.Errorf("Failed to parse transfer `%s' from: %q: %v", z.origin, tr, err) + Err = err + continue Transfer + } + } + } + Err = nil + break + } + if Err != nil { + return Err + } + + z.Lock() + z.Tree = z1.Tree + z.Apex = z1.Apex + z.Expired = false + z.Unlock() + log.Infof("Transferred: %s from %s", z.origin, tr) + return nil +} + +// shouldTransfer checks the primaries of zone, retrieves the SOA record, checks the current serial +// and the remote serial and will return true if the remote one is higher than the locally configured one. +func (z *Zone) shouldTransfer() (bool, error) { + c := new(dns.Client) + c.Net = "tcp" // do this query over TCP to minimize spoofing + m := new(dns.Msg) + m.SetQuestion(z.origin, dns.TypeSOA) + + var Err error + serial := -1 + +Transfer: + for _, tr := range z.TransferFrom { + Err = nil + ret, _, err := c.Exchange(m, tr) + if err != nil || ret.Rcode != dns.RcodeSuccess { + Err = err + continue + } + for _, a := range ret.Answer { + if a.Header().Rrtype == dns.TypeSOA { + serial = int(a.(*dns.SOA).Serial) + break Transfer + } + } + } + if serial == -1 { + return false, Err + } + if z.Apex.SOA == nil { + return true, Err + } + return less(z.Apex.SOA.Serial, uint32(serial)), Err +} + +// less returns true of a is smaller than b when taking RFC 1982 serial arithmetic into account. 
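+// Serials are compared within a half-ring window of MaxSerialIncrement so the
+// comparison survives wrap-around: less(5, 10) is true, and so is
+// less(4294967290, 5), because the second serial is taken to have wrapped past zero.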
+func less(a, b uint32) bool { + if a < b { + return (b - a) <= MaxSerialIncrement + } + return (a - b) > MaxSerialIncrement +} + +// Update updates the secondary zone according to its SOA. It will run for the life time of the server +// and uses the SOA parameters. Every refresh it will check for a new SOA number. If that fails (for all +// server) it will retry every retry interval. If the zone failed to transfer before the expire, the zone +// will be marked expired. +func (z *Zone) Update() error { + // If we don't have a SOA, we don't have a zone, wait for it to appear. + for z.Apex.SOA == nil { + time.Sleep(1 * time.Second) + } + retryActive := false + +Restart: + refresh := time.Second * time.Duration(z.Apex.SOA.Refresh) + retry := time.Second * time.Duration(z.Apex.SOA.Retry) + expire := time.Second * time.Duration(z.Apex.SOA.Expire) + + refreshTicker := time.NewTicker(refresh) + retryTicker := time.NewTicker(retry) + expireTicker := time.NewTicker(expire) + + for { + select { + case <-expireTicker.C: + if !retryActive { + break + } + z.Expired = true + + case <-retryTicker.C: + if !retryActive { + break + } + + time.Sleep(jitter(2000)) // 2s randomize + + ok, err := z.shouldTransfer() + if err != nil { + log.Warningf("Failed retry check %s", err) + continue + } + + if ok { + if err := z.TransferIn(); err != nil { + // transfer failed, leave retryActive true + break + } + } + + // no errors, stop timers and restart + retryActive = false + refreshTicker.Stop() + retryTicker.Stop() + expireTicker.Stop() + goto Restart + + case <-refreshTicker.C: + + time.Sleep(jitter(5000)) // 5s randomize + + ok, err := z.shouldTransfer() + if err != nil { + log.Warningf("Failed refresh check %s", err) + retryActive = true + continue + } + + if ok { + if err := z.TransferIn(); err != nil { + // transfer failed + retryActive = true + break + } + } + + // no errors, stop timers and restart + retryActive = false + refreshTicker.Stop() + retryTicker.Stop() + expireTicker.Stop() + goto Restart + } + } +} + +// jitter returns a random duration between [0,n) * time.Millisecond +func jitter(n int) time.Duration { + r := rand.Intn(n) + return time.Duration(r) * time.Millisecond +} + +// MaxSerialIncrement is the maximum difference between two serial numbers. If the difference between +// two serials is greater than this number, the smaller one is considered greater. +const MaxSerialIncrement uint32 = 2147483647 diff --git a/plugin/file/secondary_test.go b/plugin/file/secondary_test.go new file mode 100644 index 0000000..67d151e --- /dev/null +++ b/plugin/file/secondary_test.go @@ -0,0 +1,146 @@ +package file + +import ( + "fmt" + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +func TestLess(t *testing.T) { + const ( + min = 0 + max = 4294967295 + low = 12345 + high = 4000000000 + ) + + if less(min, max) { + t.Fatalf("Less: should be false") + } + if !less(max, min) { + t.Fatalf("Less: should be true") + } + if !less(high, low) { + t.Fatalf("Less: should be true") + } + if !less(7, 9) { + t.Fatalf("Less; should be true") + } +} + +type soa struct { + serial uint32 +} + +func (s *soa) Handler(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.SetReply(req) + switch req.Question[0].Qtype { + case dns.TypeSOA: + m.Answer = make([]dns.RR, 1) + m.Answer[0] = test.SOA(fmt.Sprintf("%s IN SOA bla. bla. 
%d 0 0 0 0 ", testZone, s.serial)) + w.WriteMsg(m) + case dns.TypeAXFR: + m.Answer = make([]dns.RR, 4) + m.Answer[0] = test.SOA(fmt.Sprintf("%s IN SOA bla. bla. %d 0 0 0 0 ", testZone, s.serial)) + m.Answer[1] = test.A(fmt.Sprintf("%s IN A 127.0.0.1", testZone)) + m.Answer[2] = test.A(fmt.Sprintf("%s IN A 127.0.0.1", testZone)) + m.Answer[3] = test.SOA(fmt.Sprintf("%s IN SOA bla. bla. %d 0 0 0 0 ", testZone, s.serial)) + w.WriteMsg(m) + } +} + +func (s *soa) TransferHandler(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.SetReply(req) + m.Answer = make([]dns.RR, 1) + m.Answer[0] = test.SOA(fmt.Sprintf("%s IN SOA bla. bla. %d 0 0 0 0 ", testZone, s.serial)) + w.WriteMsg(m) +} + +const testZone = "secondary.miek.nl." + +func TestShouldTransfer(t *testing.T) { + soa := soa{250} + + s := dnstest.NewServer(soa.Handler) + defer s.Close() + + z := NewZone("testzone", "test") + z.origin = testZone + z.TransferFrom = []string{s.Addr} + + // when we have a nil SOA (initial state) + should, err := z.shouldTransfer() + if err != nil { + t.Fatalf("Unable to run shouldTransfer: %v", err) + } + if !should { + t.Fatalf("ShouldTransfer should return true for serial: %d", soa.serial) + } + // Serial smaller + z.Apex.SOA = test.SOA(fmt.Sprintf("%s IN SOA bla. bla. %d 0 0 0 0 ", testZone, soa.serial-1)) + should, err = z.shouldTransfer() + if err != nil { + t.Fatalf("Unable to run shouldTransfer: %v", err) + } + if !should { + t.Fatalf("ShouldTransfer should return true for serial: %q", soa.serial-1) + } + // Serial equal + z.Apex.SOA = test.SOA(fmt.Sprintf("%s IN SOA bla. bla. %d 0 0 0 0 ", testZone, soa.serial)) + should, err = z.shouldTransfer() + if err != nil { + t.Fatalf("Unable to run shouldTransfer: %v", err) + } + if should { + t.Fatalf("ShouldTransfer should return false for serial: %d", soa.serial) + } +} + +func TestTransferIn(t *testing.T) { + soa := soa{250} + + s := dnstest.NewServer(soa.Handler) + defer s.Close() + + z := new(Zone) + z.origin = testZone + z.TransferFrom = []string{s.Addr} + + if err := z.TransferIn(); err != nil { + t.Fatalf("Unable to run TransferIn: %v", err) + } + if z.Apex.SOA.String() != fmt.Sprintf("%s 3600 IN SOA bla. bla. 
250 0 0 0 0", testZone) { + t.Fatalf("Unknown SOA transferred") + } +} + +func TestIsNotify(t *testing.T) { + z := new(Zone) + z.origin = testZone + state := newRequest(testZone, dns.TypeSOA) + // need to set opcode + state.Req.Opcode = dns.OpcodeNotify + + z.TransferFrom = []string{"10.240.0.1:53"} // IP from testing/responseWriter + if !z.isNotify(state) { + t.Fatal("Should have been valid notify") + } + z.TransferFrom = []string{"10.240.0.2:53"} + if z.isNotify(state) { + t.Fatal("Should have been invalid notify") + } +} + +func newRequest(zone string, qtype uint16) request.Request { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + m.SetEdns0(4097, true) + return request.Request{W: &test.ResponseWriter{}, Req: m} +} diff --git a/plugin/file/setup.go b/plugin/file/setup.go new file mode 100644 index 0000000..73a2a23 --- /dev/null +++ b/plugin/file/setup.go @@ -0,0 +1,153 @@ +package file + +import ( + "errors" + "os" + "path/filepath" + "time" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/upstream" + "github.com/coredns/coredns/plugin/transfer" +) + +func init() { plugin.Register("file", setup) } + +func setup(c *caddy.Controller) error { + zones, err := fileParse(c) + if err != nil { + return plugin.Error("file", err) + } + + f := File{Zones: zones} + // get the transfer plugin, so we can send notifies and send notifies on startup as well. + c.OnStartup(func() error { + t := dnsserver.GetConfig(c).Handler("transfer") + if t == nil { + return nil + } + f.transfer = t.(*transfer.Transfer) // if found this must be OK. + go func() { + for _, n := range zones.Names { + f.transfer.Notify(n) + } + }() + return nil + }) + + c.OnRestartFailed(func() error { + t := dnsserver.GetConfig(c).Handler("transfer") + if t == nil { + return nil + } + go func() { + for _, n := range zones.Names { + f.transfer.Notify(n) + } + }() + return nil + }) + + for _, n := range zones.Names { + z := zones.Z[n] + c.OnShutdown(z.OnShutdown) + c.OnStartup(func() error { + z.StartupOnce.Do(func() { z.Reload(f.transfer) }) + return nil + }) + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + f.Next = next + return f + }) + + return nil +} + +func fileParse(c *caddy.Controller) (Zones, error) { + z := make(map[string]*Zone) + names := []string{} + + config := dnsserver.GetConfig(c) + + var openErr error + reload := 1 * time.Minute + + for c.Next() { + // file db.file [zones...] 
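
For orientation, the comment above sketches the directive this loop parses. A typical Corefile stanza that would reach this code might look as follows; the zone and file names are placeholders, and `reload` is one of the block options handled further down.

~~~ corefile
example.org {
    file db.example.org example.org. {
        reload 30s
    }
}
~~~
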
+ if !c.NextArg() { + return Zones{}, c.ArgErr() + } + fileName := c.Val() + + origins := plugin.OriginsFromArgsOrServerBlock(c.RemainingArgs(), c.ServerBlockKeys) + if !filepath.IsAbs(fileName) && config.Root != "" { + fileName = filepath.Join(config.Root, fileName) + } + + reader, err := os.Open(filepath.Clean(fileName)) + if err != nil { + openErr = err + } + + err = func() error { + defer reader.Close() + + for i := range origins { + z[origins[i]] = NewZone(origins[i], fileName) + if openErr == nil { + reader.Seek(0, 0) + zone, err := Parse(reader, origins[i], fileName, 0) + if err != nil { + return err + } + z[origins[i]] = zone + } + names = append(names, origins[i]) + } + return nil + }() + + if err != nil { + return Zones{}, err + } + + for c.NextBlock() { + switch c.Val() { + case "reload": + t := c.RemainingArgs() + if len(t) < 1 { + return Zones{}, errors.New("reload duration value is expected") + } + d, err := time.ParseDuration(t[0]) + if err != nil { + return Zones{}, plugin.Error("file", err) + } + reload = d + case "upstream": + // remove soon + c.RemainingArgs() + + default: + return Zones{}, c.Errf("unknown property '%s'", c.Val()) + } + } + + for i := range origins { + z[origins[i]].ReloadInterval = reload + z[origins[i]].Upstream = upstream.New() + } + } + + if openErr != nil { + if reload == 0 { + // reload hasn't been set make this a fatal error + return Zones{}, plugin.Error("file", openErr) + } + log.Warningf("Failed to open %q: trying again in %s", openErr, reload) + } + return Zones{Z: z, Names: names}, nil +} diff --git a/plugin/file/setup_test.go b/plugin/file/setup_test.go new file mode 100644 index 0000000..1d3b8dc --- /dev/null +++ b/plugin/file/setup_test.go @@ -0,0 +1,124 @@ +package file + +import ( + "testing" + "time" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/plugin/test" +) + +func TestFileParse(t *testing.T) { + zoneFileName1, rm, err := test.TempFile(".", dbMiekNL) + if err != nil { + t.Fatal(err) + } + defer rm() + + zoneFileName2, rm, err := test.TempFile(".", dbDnssexNLSigned) + if err != nil { + t.Fatal(err) + } + defer rm() + + tests := []struct { + inputFileRules string + shouldErr bool + expectedZones Zones + }{ + { + `file ` + zoneFileName1 + ` miek.nl.`, + false, + Zones{Names: []string{"miek.nl."}}, + }, + { + `file ` + zoneFileName2 + ` dnssex.nl.`, + false, + Zones{Names: []string{"dnssex.nl."}}, + }, + { + `file ` + zoneFileName2 + ` 10.0.0.0/8`, + false, + Zones{Names: []string{"10.in-addr.arpa."}}, + }, + // errors. + { + `file ` + zoneFileName1 + ` miek.nl { + transfer from 127.0.0.1 + }`, + true, + Zones{}, + }, + { + `file`, + true, + Zones{}, + }, + { + `file ` + zoneFileName1 + ` example.net. { + no_reload + }`, + true, + Zones{}, + }, + { + `file ` + zoneFileName1 + ` example.net. 
{ + no_rebloat + }`, + true, + Zones{}, + }, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.inputFileRules) + actualZones, err := fileParse(c) + + if err == nil && test.shouldErr { + t.Fatalf("Test %d expected errors, but got no error", i) + } else if err != nil && !test.shouldErr { + t.Fatalf("Test %d expected no errors, but got '%v'", i, err) + } else { + if len(actualZones.Names) != len(test.expectedZones.Names) { + t.Fatalf("Test %d expected %v, got %v", i, test.expectedZones.Names, actualZones.Names) + } + for j, name := range test.expectedZones.Names { + if actualZones.Names[j] != name { + t.Fatalf("Test %d expected %v for %d th zone, got %v", i, name, j, actualZones.Names[j]) + } + } + } + } +} + +func TestParseReload(t *testing.T) { + name, rm, err := test.TempFile(".", dbMiekNL) + if err != nil { + t.Fatal(err) + } + defer rm() + + tests := []struct { + input string + reload time.Duration + }{ + { + `file ` + name + ` example.org.`, + 1 * time.Minute, + }, + { + `file ` + name + ` example.org. { + reload 5s + }`, + 5 * time.Second, + }, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + z, _ := fileParse(c) + if x := z.Z["example.org."].ReloadInterval; x != test.reload { + t.Errorf("Test %d expected reload to be %s, but got %s", i, test.reload, x) + } + } +} diff --git a/plugin/file/shutdown.go b/plugin/file/shutdown.go new file mode 100644 index 0000000..9aa5989 --- /dev/null +++ b/plugin/file/shutdown.go @@ -0,0 +1,9 @@ +package file + +// OnShutdown shuts down any running go-routines for this zone. +func (z *Zone) OnShutdown() error { + if 0 < z.ReloadInterval { + z.reloadShutdown <- true + } + return nil +} diff --git a/plugin/file/tree/all.go b/plugin/file/tree/all.go new file mode 100644 index 0000000..e1fc5b3 --- /dev/null +++ b/plugin/file/tree/all.go @@ -0,0 +1,21 @@ +package tree + +// All traverses tree and returns all elements. +func (t *Tree) All() []*Elem { + if t.Root == nil { + return nil + } + found := t.Root.all(nil) + return found +} + +func (n *Node) all(found []*Elem) []*Elem { + if n.Left != nil { + found = n.Left.all(found) + } + found = append(found, n.Elem) + if n.Right != nil { + found = n.Right.all(found) + } + return found +} diff --git a/plugin/file/tree/auth_walk.go b/plugin/file/tree/auth_walk.go new file mode 100644 index 0000000..1f43671 --- /dev/null +++ b/plugin/file/tree/auth_walk.go @@ -0,0 +1,58 @@ +package tree + +import ( + "github.com/miekg/dns" +) + +// AuthWalk performs fn on all authoritative values stored in the tree in +// pre-order depth first. If a non-nil error is returned the AuthWalk was interrupted +// by an fn returning that error. If fn alters stored values' sort +// relationships, future tree operation behaviors are undefined. +// +// The fn function will be called with 3 arguments, the current element, a map containing all +// the RRs for this element and a boolean if this name is considered authoritative. +func (t *Tree) AuthWalk(fn func(*Elem, map[uint16][]dns.RR, bool) error) error { + if t.Root == nil { + return nil + } + return t.Root.authwalk(make(map[string]struct{}), fn) +} + +func (n *Node) authwalk(ns map[string]struct{}, fn func(*Elem, map[uint16][]dns.RR, bool) error) error { + if n.Left != nil { + if err := n.Left.authwalk(ns, fn); err != nil { + return err + } + } + + // Check if the current name is a subdomain of *any* of the delegated names we've seen, if so, skip this name. 
+ // The ordering of the tree and how we walk if guarantees we see parents first. + if n.Elem.Type(dns.TypeNS) != nil { + ns[n.Elem.Name()] = struct{}{} + } + + auth := true + i := 0 + for { + j, end := dns.NextLabel(n.Elem.Name(), i) + if end { + break + } + if _, ok := ns[n.Elem.Name()[j:]]; ok { + auth = false + break + } + i++ + } + + if err := fn(n.Elem, n.Elem.m, auth); err != nil { + return err + } + + if n.Right != nil { + if err := n.Right.authwalk(ns, fn); err != nil { + return err + } + } + return nil +} diff --git a/plugin/file/tree/elem.go b/plugin/file/tree/elem.go new file mode 100644 index 0000000..c190964 --- /dev/null +++ b/plugin/file/tree/elem.go @@ -0,0 +1,101 @@ +package tree + +import "github.com/miekg/dns" + +// Elem is an element in the tree. +type Elem struct { + m map[uint16][]dns.RR + name string // owner name +} + +// newElem returns a new elem. +func newElem(rr dns.RR) *Elem { + e := Elem{m: make(map[uint16][]dns.RR)} + e.m[rr.Header().Rrtype] = []dns.RR{rr} + return &e +} + +// Types returns the types of the records in e. The returned list is not sorted. +func (e *Elem) Types() []uint16 { + t := make([]uint16, len(e.m)) + i := 0 + for ty := range e.m { + t[i] = ty + i++ + } + return t +} + +// Type returns the RRs with type qtype from e. +func (e *Elem) Type(qtype uint16) []dns.RR { return e.m[qtype] } + +// TypeForWildcard returns the RRs with type qtype from e. The ownername returned is set to qname. +func (e *Elem) TypeForWildcard(qtype uint16, qname string) []dns.RR { + rrs := e.m[qtype] + + if rrs == nil { + return nil + } + + copied := make([]dns.RR, len(rrs)) + for i := range rrs { + copied[i] = dns.Copy(rrs[i]) + copied[i].Header().Name = qname + } + return copied +} + +// All returns all RRs from e, regardless of type. +func (e *Elem) All() []dns.RR { + list := []dns.RR{} + for _, rrs := range e.m { + list = append(list, rrs...) + } + return list +} + +// Name returns the name for this node. +func (e *Elem) Name() string { + if e.name != "" { + return e.name + } + for _, rrs := range e.m { + e.name = rrs[0].Header().Name + return e.name + } + return "" +} + +// Empty returns true is e does not contain any RRs, i.e. is an empty-non-terminal. +func (e *Elem) Empty() bool { return len(e.m) == 0 } + +// Insert inserts rr into e. If rr is equal to existing RRs, the RR will be added anyway. +func (e *Elem) Insert(rr dns.RR) { + t := rr.Header().Rrtype + if e.m == nil { + e.m = make(map[uint16][]dns.RR) + e.m[t] = []dns.RR{rr} + return + } + rrs, ok := e.m[t] + if !ok { + e.m[t] = []dns.RR{rr} + return + } + + rrs = append(rrs, rr) + e.m[t] = rrs +} + +// Delete removes all RRs of type rr.Header().Rrtype from e. +func (e *Elem) Delete(rr dns.RR) { + if e.m == nil { + return + } + + t := rr.Header().Rrtype + delete(e.m, t) +} + +// Less is a tree helper function that calls less. +func Less(a *Elem, name string) int { return less(name, a.Name()) } diff --git a/plugin/file/tree/glue.go b/plugin/file/tree/glue.go new file mode 100644 index 0000000..937ae54 --- /dev/null +++ b/plugin/file/tree/glue.go @@ -0,0 +1,44 @@ +package tree + +import ( + "github.com/coredns/coredns/plugin/file/rrutil" + + "github.com/miekg/dns" +) + +// Glue returns any potential glue records for nsrrs. +func (t *Tree) Glue(nsrrs []dns.RR, do bool) []dns.RR { + glue := []dns.RR{} + for _, rr := range nsrrs { + if ns, ok := rr.(*dns.NS); ok && dns.IsSubDomain(ns.Header().Name, ns.Ns) { + glue = append(glue, t.searchGlue(ns.Ns, do)...) 
+ } + } + return glue +} + +// searchGlue looks up A and AAAA for name. +func (t *Tree) searchGlue(name string, do bool) []dns.RR { + glue := []dns.RR{} + + // A + if elem, found := t.Search(name); found { + glue = append(glue, elem.Type(dns.TypeA)...) + if do { + sigs := elem.Type(dns.TypeRRSIG) + sigs = rrutil.SubTypeSignature(sigs, dns.TypeA) + glue = append(glue, sigs...) + } + } + + // AAAA + if elem, found := t.Search(name); found { + glue = append(glue, elem.Type(dns.TypeAAAA)...) + if do { + sigs := elem.Type(dns.TypeRRSIG) + sigs = rrutil.SubTypeSignature(sigs, dns.TypeAAAA) + glue = append(glue, sigs...) + } + } + return glue +} diff --git a/plugin/file/tree/less.go b/plugin/file/tree/less.go new file mode 100644 index 0000000..7421cf0 --- /dev/null +++ b/plugin/file/tree/less.go @@ -0,0 +1,59 @@ +package tree + +import ( + "bytes" + + "github.com/miekg/dns" +) + +// less returns <0 when a is less than b, 0 when they are equal and +// >0 when a is larger than b. +// The function orders names in DNSSEC canonical order: RFC 4034s section-6.1 +// +// See https://bert-hubert.blogspot.co.uk/2015/10/how-to-do-fast-canonical-ordering-of.html +// for a blog article on this implementation, although here we still go label by label. +// +// The values of a and b are *not* lowercased before the comparison! +func less(a, b string) int { + i := 1 + aj := len(a) + bj := len(b) + for { + ai, oka := dns.PrevLabel(a, i) + bi, okb := dns.PrevLabel(b, i) + if oka && okb { + return 0 + } + + // sadly this []byte will allocate... TODO(miek): check if this is needed + // for a name, otherwise compare the strings. + ab := []byte(a[ai:aj]) + bb := []byte(b[bi:bj]) + doDDD(ab) + doDDD(bb) + + res := bytes.Compare(ab, bb) + if res != 0 { + return res + } + + i++ + aj, bj = ai, bi + } +} + +func doDDD(b []byte) { + lb := len(b) + for i := 0; i < lb; i++ { + if i+3 < lb && b[i] == '\\' && isDigit(b[i+1]) && isDigit(b[i+2]) && isDigit(b[i+3]) { + b[i] = dddToByte(b[i:]) + for j := i + 1; j < lb-3; j++ { + b[j] = b[j+3] + } + lb -= 3 + } + } +} + +func isDigit(b byte) bool { return b >= '0' && b <= '9' } +func dddToByte(s []byte) byte { return (s[1]-'0')*100 + (s[2]-'0')*10 + (s[3] - '0') } diff --git a/plugin/file/tree/less_test.go b/plugin/file/tree/less_test.go new file mode 100644 index 0000000..f2559de --- /dev/null +++ b/plugin/file/tree/less_test.go @@ -0,0 +1,80 @@ +package tree + +import ( + "sort" + "strings" + "testing" +) + +type set []string + +func (p set) Len() int { return len(p) } +func (p set) Swap(i, j int) { p[i], p[j] = p[j], p[i] } +func (p set) Less(i, j int) bool { d := less(p[i], p[j]); return d <= 0 } + +func TestLess(t *testing.T) { + tests := []struct { + in []string + out []string + }{ + { + []string{"aaa.powerdns.de", "bbb.powerdns.net.", "xxx.powerdns.com."}, + []string{"xxx.powerdns.com.", "aaa.powerdns.de", "bbb.powerdns.net."}, + }, + { + []string{"aaa.POWERDNS.de", "bbb.PoweRdnS.net.", "xxx.powerdns.com."}, + []string{"xxx.powerdns.com.", "aaa.POWERDNS.de", "bbb.PoweRdnS.net."}, + }, + { + []string{"aaa.aaaa.aa.", "aa.aaa.a.", "bbb.bbbb.bb."}, + []string{"aa.aaa.a.", "aaa.aaaa.aa.", "bbb.bbbb.bb."}, + }, + { + []string{"aaaaa.", "aaa.", "bbb."}, + []string{"aaa.", "aaaaa.", "bbb."}, + }, + { + []string{"a.a.a.a.", "a.a.", "a.a.a."}, + []string{"a.a.", "a.a.a.", "a.a.a.a."}, + }, + { + []string{"example.", "z.example.", "a.example."}, + []string{"example.", "a.example.", "z.example."}, + }, + { + []string{"a.example.", "Z.a.example.", "z.example.", "yljkjljk.a.example.", 
"\\001.z.example.", "example.", "*.z.example.", "\\200.z.example.", "zABC.a.EXAMPLE."}, + []string{"example.", "a.example.", "yljkjljk.a.example.", "Z.a.example.", "zABC.a.EXAMPLE.", "z.example.", "\\001.z.example.", "*.z.example.", "\\200.z.example."}, + }, + { + // RFC3034 example. + []string{"a.example.", "Z.a.example.", "z.example.", "yljkjljk.a.example.", "example.", "*.z.example.", "zABC.a.EXAMPLE."}, + []string{"example.", "a.example.", "yljkjljk.a.example.", "Z.a.example.", "zABC.a.EXAMPLE.", "z.example.", "*.z.example."}, + }, + } + +Tests: + for j, test := range tests { + // Need to lowercase these example as the Less function does lowercase for us anymore. + for i, b := range test.in { + test.in[i] = strings.ToLower(b) + } + for i, b := range test.out { + test.out[i] = strings.ToLower(b) + } + + sort.Sort(set(test.in)) + for i := 0; i < len(test.in); i++ { + if test.in[i] != test.out[i] { + t.Errorf("Test %d: expected %s, got %s", j, test.out[i], test.in[i]) + n := "" + for k, in := range test.in { + if k+1 == len(test.in) { + n = "\n" + } + t.Logf("%s <-> %s\n%s", in, test.out[k], n) + } + continue Tests + } + } + } +} diff --git a/plugin/file/tree/print.go b/plugin/file/tree/print.go new file mode 100644 index 0000000..b2df70e --- /dev/null +++ b/plugin/file/tree/print.go @@ -0,0 +1,62 @@ +package tree + +import "fmt" + +// Print prints a Tree. Main use is to aid in debugging. +func (t *Tree) Print() { + if t.Root == nil { + fmt.Println("<nil>") + } + t.Root.print() +} + +func (n *Node) print() { + q := newQueue() + q.push(n) + + nodesInCurrentLevel := 1 + nodesInNextLevel := 0 + + for !q.empty() { + do := q.pop() + nodesInCurrentLevel-- + + if do != nil { + fmt.Print(do.Elem.Name(), " ") + q.push(do.Left) + q.push(do.Right) + nodesInNextLevel += 2 + } + if nodesInCurrentLevel == 0 { + fmt.Println() + nodesInCurrentLevel = nodesInNextLevel + nodesInNextLevel = 0 + } + } + fmt.Println() +} + +type queue []*Node + +// newQueue returns a new queue. +func newQueue() queue { + q := queue([]*Node{}) + return q +} + +// push pushes n to the end of the queue. +func (q *queue) push(n *Node) { + *q = append(*q, n) +} + +// pop pops the first element off the queue. +func (q *queue) pop() *Node { + n := (*q)[0] + *q = (*q)[1:] + return n +} + +// empty returns true when the queue contains zero nodes. 
+func (q *queue) empty() bool { + return len(*q) == 0 +} diff --git a/plugin/file/tree/print_test.go b/plugin/file/tree/print_test.go new file mode 100644 index 0000000..20ad37d --- /dev/null +++ b/plugin/file/tree/print_test.go @@ -0,0 +1,100 @@ +package tree + +import ( + "net" + "os" + "strings" + "testing" + + "github.com/miekg/dns" +) + +func TestPrint(t *testing.T) { + rr1 := dns.A{ + Hdr: dns.RR_Header{ + Name: dns.Fqdn("server1.example.com"), + Rrtype: 1, + Class: 1, + Ttl: 3600, + Rdlength: 0, + }, + A: net.IPv4(10, 0, 1, 1), + } + rr2 := dns.A{ + Hdr: dns.RR_Header{ + Name: dns.Fqdn("server2.example.com"), + Rrtype: 1, + Class: 1, + Ttl: 3600, + Rdlength: 0, + }, + A: net.IPv4(10, 0, 1, 2), + } + rr3 := dns.A{ + Hdr: dns.RR_Header{ + Name: dns.Fqdn("server3.example.com"), + Rrtype: 1, + Class: 1, + Ttl: 3600, + Rdlength: 0, + }, + A: net.IPv4(10, 0, 1, 3), + } + rr4 := dns.A{ + Hdr: dns.RR_Header{ + Name: dns.Fqdn("server4.example.com"), + Rrtype: 1, + Class: 1, + Ttl: 3600, + Rdlength: 0, + }, + A: net.IPv4(10, 0, 1, 4), + } + tree := Tree{ + Root: nil, + Count: 0, + } + tree.Insert(&rr1) + tree.Insert(&rr2) + tree.Insert(&rr3) + tree.Insert(&rr4) + + /** + build a LLRB tree, the height of the tree is 3, look like: + + server2.example.com. + / \ + server1.example.com. server4.example.com. + / + server3.example.com. + + */ + + f, err := os.CreateTemp("", "print_test_tmp") + if err != nil { + t.Error(err) + } + defer os.Remove(f.Name()) + //Redirect the printed results to a tmp file for later comparison + os.Stdout = f + + tree.Print() + /** + server2.example.com. + server1.example.com. server4.example.com. + server3.example.com. + */ + + buf := make([]byte, 256) + f.Seek(0, 0) + _, err = f.Read(buf) + if err != nil { + f.Close() + t.Error(err) + } + height := strings.Count(string(buf), ". \n") + //Compare the height of the print with the actual height of the tree + if height != 3 { + t.Fatal("The number of rows is inconsistent with the actual number of rows in the tree itself.") + } +} diff --git a/plugin/file/tree/tree.go b/plugin/file/tree/tree.go new file mode 100644 index 0000000..a6caafe --- /dev/null +++ b/plugin/file/tree/tree.go @@ -0,0 +1,453 @@ +// Copyright ©2012 The bíogo Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found at the end of this file. + +// Package tree implements Left-Leaning Red Black trees as described by Robert Sedgewick. +// +// More details relating to the implementation are available at the following locations: +// +// http://www.cs.princeton.edu/~rs/talks/LLRB/LLRB.pdf +// http://www.cs.princeton.edu/~rs/talks/LLRB/Java/RedBlackBST.java +// http://www.teachsolaisgames.com/articles/balanced_left_leaning.html +// +// Heavily modified by Miek Gieben for use in DNS zones. +package tree + +import "github.com/miekg/dns" + +const ( + td234 = iota + bu23 +) + +// Operation mode of the LLRB tree. +const mode = bu23 + +func init() { + if mode != td234 && mode != bu23 { + panic("tree: unknown mode") + } +} + +// A Color represents the color of a Node. +type Color bool + +const ( + // Red as false give us the defined behaviour that new nodes are red. Although this + // is incorrect for the root node, that is resolved on the first insertion. + red Color = false + black Color = true +) + +// A Node represents a node in the LLRB tree. +type Node struct { + Elem *Elem + Left, Right *Node + Color Color +} + +// A Tree manages the root node of an LLRB tree. Public methods are exposed through this type. 
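
A minimal usage sketch of the public API this type exposes (Insert, Search, Len), assuming the import paths used elsewhere in this commit; the record data is made up.

~~~ go
package main

import (
	"fmt"

	"github.com/coredns/coredns/plugin/file/tree"
	"github.com/miekg/dns"
)

func main() {
	t := &tree.Tree{}

	// The tree keys nodes on the owner name of the inserted RRs.
	for _, s := range []string{
		"a.example.org. 3600 IN A 192.0.2.1",
		"b.example.org. 3600 IN AAAA 2001:db8::1",
	} {
		rr, err := dns.NewRR(s)
		if err != nil {
			panic(err)
		}
		t.Insert(rr)
	}

	// Exact-match lookup by owner name, then select a type from the element.
	if e, found := t.Search("a.example.org."); found {
		fmt.Println(e.Type(dns.TypeA)) // the A record inserted above
	}
	fmt.Println(t.Len()) // 2
}
~~~
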
+type Tree struct { + Root *Node // Root node of the tree. + Count int // Number of elements stored. +} + +// Helper methods + +// color returns the effect color of a Node. A nil node returns black. +func (n *Node) color() Color { + if n == nil { + return black + } + return n.Color +} + +// (a,c)b -rotL-> ((a,)b,)c +func (n *Node) rotateLeft() (root *Node) { + // Assumes: n has two children. + root = n.Right + n.Right = root.Left + root.Left = n + root.Color = n.Color + n.Color = red + return +} + +// (a,c)b -rotR-> (,(,c)b)a +func (n *Node) rotateRight() (root *Node) { + // Assumes: n has two children. + root = n.Left + n.Left = root.Right + root.Right = n + root.Color = n.Color + n.Color = red + return +} + +// (aR,cR)bB -flipC-> (aB,cB)bR | (aB,cB)bR -flipC-> (aR,cR)bB +func (n *Node) flipColors() { + // Assumes: n has two children. + n.Color = !n.Color + n.Left.Color = !n.Left.Color + n.Right.Color = !n.Right.Color +} + +// fixUp ensures that black link balance is correct, that red nodes lean left, +// and that 4 nodes are split in the case of BU23 and properly balanced in TD234. +func (n *Node) fixUp() *Node { + if n.Right.color() == red { + if mode == td234 && n.Right.Left.color() == red { + n.Right = n.Right.rotateRight() + } + n = n.rotateLeft() + } + if n.Left.color() == red && n.Left.Left.color() == red { + n = n.rotateRight() + } + if mode == bu23 && n.Left.color() == red && n.Right.color() == red { + n.flipColors() + } + return n +} + +func (n *Node) moveRedLeft() *Node { + n.flipColors() + if n.Right.Left.color() == red { + n.Right = n.Right.rotateRight() + n = n.rotateLeft() + n.flipColors() + if mode == td234 && n.Right.Right.color() == red { + n.Right = n.Right.rotateLeft() + } + } + return n +} + +func (n *Node) moveRedRight() *Node { + n.flipColors() + if n.Left.Left.color() == red { + n = n.rotateRight() + n.flipColors() + } + return n +} + +// Len returns the number of elements stored in the Tree. +func (t *Tree) Len() int { + return t.Count +} + +// Search returns the first match of qname in the Tree. +func (t *Tree) Search(qname string) (*Elem, bool) { + if t.Root == nil { + return nil, false + } + n, res := t.Root.search(qname) + if n == nil { + return nil, res + } + return n.Elem, res +} + +// search searches the tree for qname and type. +func (n *Node) search(qname string) (*Node, bool) { + for n != nil { + switch c := Less(n.Elem, qname); { + case c == 0: + return n, true + case c < 0: + n = n.Left + default: + n = n.Right + } + } + + return n, false +} + +// Insert inserts rr into the Tree at the first match found +// with e or when a nil node is reached. +func (t *Tree) Insert(rr dns.RR) { + var d int + t.Root, d = t.Root.insert(rr) + t.Count += d + t.Root.Color = black +} + +// insert inserts rr in to the tree. 
+func (n *Node) insert(rr dns.RR) (root *Node, d int) { + if n == nil { + return &Node{Elem: newElem(rr)}, 1 + } else if n.Elem == nil { + n.Elem = newElem(rr) + return n, 1 + } + + if mode == td234 { + if n.Left.color() == red && n.Right.color() == red { + n.flipColors() + } + } + + switch c := Less(n.Elem, rr.Header().Name); { + case c == 0: + n.Elem.Insert(rr) + case c < 0: + n.Left, d = n.Left.insert(rr) + default: + n.Right, d = n.Right.insert(rr) + } + + if n.Right.color() == red && n.Left.color() == black { + n = n.rotateLeft() + } + if n.Left.color() == red && n.Left.Left.color() == red { + n = n.rotateRight() + } + + if mode == bu23 { + if n.Left.color() == red && n.Right.color() == red { + n.flipColors() + } + } + + root = n + + return +} + +// DeleteMin deletes the node with the minimum value in the tree. +func (t *Tree) DeleteMin() { + if t.Root == nil { + return + } + var d int + t.Root, d = t.Root.deleteMin() + t.Count += d + if t.Root == nil { + return + } + t.Root.Color = black +} + +func (n *Node) deleteMin() (root *Node, d int) { + if n.Left == nil { + return nil, -1 + } + if n.Left.color() == black && n.Left.Left.color() == black { + n = n.moveRedLeft() + } + n.Left, d = n.Left.deleteMin() + + root = n.fixUp() + + return +} + +// DeleteMax deletes the node with the maximum value in the tree. +func (t *Tree) DeleteMax() { + if t.Root == nil { + return + } + var d int + t.Root, d = t.Root.deleteMax() + t.Count += d + if t.Root == nil { + return + } + t.Root.Color = black +} + +func (n *Node) deleteMax() (root *Node, d int) { + if n.Left != nil && n.Left.color() == red { + n = n.rotateRight() + } + if n.Right == nil { + return nil, -1 + } + if n.Right.color() == black && n.Right.Left.color() == black { + n = n.moveRedRight() + } + n.Right, d = n.Right.deleteMax() + + root = n.fixUp() + + return +} + +// Delete removes all RRs of type rr.Header().Rrtype from e. If after the deletion of rr the node is empty the +// entire node is deleted. +func (t *Tree) Delete(rr dns.RR) { + if t.Root == nil { + return + } + + el, _ := t.Search(rr.Header().Name) + if el == nil { + return + } + el.Delete(rr) + if el.Empty() { + t.deleteNode(rr) + } +} + +// DeleteNode deletes the node that matches rr according to Less(). +func (t *Tree) deleteNode(rr dns.RR) { + if t.Root == nil { + return + } + var d int + t.Root, d = t.Root.delete(rr) + t.Count += d + if t.Root == nil { + return + } + t.Root.Color = black +} + +func (n *Node) delete(rr dns.RR) (root *Node, d int) { + if Less(n.Elem, rr.Header().Name) < 0 { + if n.Left != nil { + if n.Left.color() == black && n.Left.Left.color() == black { + n = n.moveRedLeft() + } + n.Left, d = n.Left.delete(rr) + } + } else { + if n.Left.color() == red { + n = n.rotateRight() + } + if n.Right == nil && Less(n.Elem, rr.Header().Name) == 0 { + return nil, -1 + } + if n.Right != nil { + if n.Right.color() == black && n.Right.Left.color() == black { + n = n.moveRedRight() + } + if Less(n.Elem, rr.Header().Name) == 0 { + n.Elem = n.Right.min().Elem + n.Right, d = n.Right.deleteMin() + } else { + n.Right, d = n.Right.delete(rr) + } + } + } + + root = n.fixUp() + return +} + +// Min returns the minimum value stored in the tree. +func (t *Tree) Min() *Elem { + if t.Root == nil { + return nil + } + return t.Root.min().Elem +} + +func (n *Node) min() *Node { + for ; n.Left != nil; n = n.Left { + } + return n +} + +// Max returns the maximum value stored in the tree. 
+func (t *Tree) Max() *Elem { + if t.Root == nil { + return nil + } + return t.Root.max().Elem +} + +func (n *Node) max() *Node { + for ; n.Right != nil; n = n.Right { + } + return n +} + +// Prev returns the greatest value equal to or less than the qname according to Less(). +func (t *Tree) Prev(qname string) (*Elem, bool) { + if t.Root == nil { + return nil, false + } + + n := t.Root.floor(qname) + if n == nil { + return nil, false + } + return n.Elem, true +} + +func (n *Node) floor(qname string) *Node { + if n == nil { + return nil + } + switch c := Less(n.Elem, qname); { + case c == 0: + return n + case c <= 0: + return n.Left.floor(qname) + default: + if r := n.Right.floor(qname); r != nil { + return r + } + } + return n +} + +// Next returns the smallest value equal to or greater than the qname according to Less(). +func (t *Tree) Next(qname string) (*Elem, bool) { + if t.Root == nil { + return nil, false + } + n := t.Root.ceil(qname) + if n == nil { + return nil, false + } + return n.Elem, true +} + +func (n *Node) ceil(qname string) *Node { + if n == nil { + return nil + } + switch c := Less(n.Elem, qname); { + case c == 0: + return n + case c > 0: + return n.Right.ceil(qname) + default: + if l := n.Left.ceil(qname); l != nil { + return l + } + } + return n +} + +/* +Copyright ©2012 The bíogo Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. +* Neither the name of the bíogo project nor the names of its authors and + contributors may be used to endorse or promote products derived from this + software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ diff --git a/plugin/file/tree/walk.go b/plugin/file/tree/walk.go new file mode 100644 index 0000000..e315eb0 --- /dev/null +++ b/plugin/file/tree/walk.go @@ -0,0 +1,33 @@ +package tree + +import "github.com/miekg/dns" + +// Walk performs fn on all authoritative values stored in the tree in +// in-order depth first. If a non-nil error is returned the Walk was interrupted +// by an fn returning that error. If fn alters stored values' sort +// relationships, future tree operation behaviors are undefined. 
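
As a quick illustration of how Walk is typically driven, the sketch below tallies every RR stored in a tree; the `countRRs` helper is illustrative only and not part of this commit.

~~~ go
package main

import (
	"fmt"

	"github.com/coredns/coredns/plugin/file/tree"
	"github.com/miekg/dns"
)

// countRRs visits every element in DNSSEC order and adds up the stored RRs.
// Illustrative helper only; not part of the plugin.
func countRRs(t *tree.Tree) (int, error) {
	n := 0
	err := t.Walk(func(_ *tree.Elem, m map[uint16][]dns.RR) error {
		for _, rrs := range m {
			n += len(rrs)
		}
		return nil
	})
	return n, err
}

func main() {
	t := &tree.Tree{}
	rr, err := dns.NewRR("www.example.org. 3600 IN A 192.0.2.1")
	if err != nil {
		panic(err)
	}
	t.Insert(rr)
	n, _ := countRRs(t)
	fmt.Println(n) // 1
}
~~~
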
+func (t *Tree) Walk(fn func(*Elem, map[uint16][]dns.RR) error) error { + if t.Root == nil { + return nil + } + return t.Root.walk(fn) +} + +func (n *Node) walk(fn func(*Elem, map[uint16][]dns.RR) error) error { + if n.Left != nil { + if err := n.Left.walk(fn); err != nil { + return err + } + } + + if err := fn(n.Elem, n.Elem.m); err != nil { + return err + } + + if n.Right != nil { + if err := n.Right.walk(fn); err != nil { + return err + } + } + return nil +} diff --git a/plugin/file/wildcard.go b/plugin/file/wildcard.go new file mode 100644 index 0000000..7e8e806 --- /dev/null +++ b/plugin/file/wildcard.go @@ -0,0 +1,13 @@ +package file + +import "github.com/miekg/dns" + +// replaceWithAsteriskLabel replaces the left most label with '*'. +func replaceWithAsteriskLabel(qname string) (wildcard string) { + i, shot := dns.NextLabel(qname, 0) + if shot { + return "" + } + + return "*." + qname[i:] +} diff --git a/plugin/file/wildcard_test.go b/plugin/file/wildcard_test.go new file mode 100644 index 0000000..fc6ad12 --- /dev/null +++ b/plugin/file/wildcard_test.go @@ -0,0 +1,298 @@ +package file + +import ( + "context" + "strings" + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +// these examples don't have an additional opt RR set, because that's gets added by the server. +var wildcardTestCases = []test.Case{ + { + Qname: "wild.dnssex.nl.", Qtype: dns.TypeTXT, + Answer: []dns.RR{ + test.TXT(`wild.dnssex.nl. 1800 IN TXT "Doing It Safe Is Better"`), + }, + Ns: dnssexAuth[:len(dnssexAuth)-1], // remove RRSIG on the end + }, + { + Qname: "a.wild.dnssex.nl.", Qtype: dns.TypeTXT, + Answer: []dns.RR{ + test.TXT(`a.wild.dnssex.nl. 1800 IN TXT "Doing It Safe Is Better"`), + }, + Ns: dnssexAuth[:len(dnssexAuth)-1], // remove RRSIG on the end + }, + { + Qname: "wild.dnssex.nl.", Qtype: dns.TypeTXT, Do: true, + Answer: []dns.RR{ + test.RRSIG("wild.dnssex.nl. 1800 IN RRSIG TXT 8 2 1800 20160428190224 20160329190224 14460 dnssex.nl. FUZSTyvZfeuuOpCm"), + test.TXT(`wild.dnssex.nl. 1800 IN TXT "Doing It Safe Is Better"`), + }, + Ns: append([]dns.RR{ + test.NSEC("a.dnssex.nl. 14400 IN NSEC www.dnssex.nl. A AAAA RRSIG NSEC"), + test.RRSIG("a.dnssex.nl. 14400 IN RRSIG NSEC 8 3 14400 20160428190224 20160329190224 14460 dnssex.nl. S+UMs2ySgRaaRY"), + }, dnssexAuth...), + }, + { + Qname: "a.wild.dnssex.nl.", Qtype: dns.TypeTXT, Do: true, + Answer: []dns.RR{ + test.RRSIG("a.wild.dnssex.nl. 1800 IN RRSIG TXT 8 2 1800 20160428190224 20160329190224 14460 dnssex.nl. FUZSTyvZfeuuOpCm"), + test.TXT(`a.wild.dnssex.nl. 1800 IN TXT "Doing It Safe Is Better"`), + }, + Ns: append([]dns.RR{ + test.NSEC("a.dnssex.nl. 14400 IN NSEC www.dnssex.nl. A AAAA RRSIG NSEC"), + test.RRSIG("a.dnssex.nl. 14400 IN RRSIG NSEC 8 3 14400 20160428190224 20160329190224 14460 dnssex.nl. S+UMs2ySgRaaRY"), + }, dnssexAuth...), + }, + // nodata responses + { + Qname: "wild.dnssex.nl.", Qtype: dns.TypeSRV, + Ns: []dns.RR{ + test.SOA(`dnssex.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 1459281744 14400 3600 604800 14400`), + }, + }, + { + Qname: "wild.dnssex.nl.", Qtype: dns.TypeSRV, Do: true, + Ns: []dns.RR{ + // TODO(miek): needs closest encloser proof as well? This is the wrong answer + test.NSEC(`*.dnssex.nl. 14400 IN NSEC a.dnssex.nl. TXT RRSIG NSEC`), + test.RRSIG(`*.dnssex.nl. 14400 IN RRSIG NSEC 8 2 14400 20160428190224 20160329190224 14460 dnssex.nl. os6INm6q2eXknD5z8TaaDOV+Ge/Ko+2dXnKP+J1fqJzafXJVH1F0nDrcXmMlR6jlBHA=`), + test.RRSIG(`dnssex.nl. 
1800 IN RRSIG SOA 8 2 1800 20160428190224 20160329190224 14460 dnssex.nl. CA/Y3m9hCOiKC/8ieSOv8SeP964Bq++lyH8BZJcTaabAsERs4xj5PRtcxicwQXZiF8fYUCpROlUS0YR8Cdw=`), + test.SOA(`dnssex.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 1459281744 14400 3600 604800 14400`), + }, + }, +} + +var dnssexAuth = []dns.RR{ + test.NS("dnssex.nl. 1800 IN NS linode.atoom.net."), + test.NS("dnssex.nl. 1800 IN NS ns-ext.nlnetlabs.nl."), + test.NS("dnssex.nl. 1800 IN NS omval.tednet.nl."), + test.RRSIG("dnssex.nl. 1800 IN RRSIG NS 8 2 1800 20160428190224 20160329190224 14460 dnssex.nl. dLIeEvP86jj5ndkcLzhgvWixTABjWAGRTGQsPsVDFXsGMf9TGGC9FEomgkCVeNC0="), +} + +func TestLookupWildcard(t *testing.T) { + zone, err := Parse(strings.NewReader(dbDnssexNLSigned), testzone1, "stdin", 0) + if err != nil { + t.Fatalf("Expect no error when reading zone, got %q", err) + } + + fm := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{testzone1: zone}, Names: []string{testzone1}}} + ctx := context.TODO() + + for _, tc := range wildcardTestCases { + m := tc.Msg() + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err := fm.ServeDNS(ctx, rec, m) + if err != nil { + t.Errorf("Expected no error, got %v", err) + return + } + + resp := rec.Msg + if err := test.SortAndCheck(resp, tc); err != nil { + t.Error(err) + } + } +} + +var wildcardDoubleTestCases = []test.Case{ + { + Qname: "wild.w.example.org.", Qtype: dns.TypeTXT, + Answer: []dns.RR{ + test.TXT(`wild.w.example.org. IN TXT "Wildcard"`), + }, + Ns: exampleAuth, + }, + { + Qname: "wild.c.example.org.", Qtype: dns.TypeTXT, + Answer: []dns.RR{ + test.TXT(`wild.c.example.org. IN TXT "c Wildcard"`), + }, + Ns: exampleAuth, + }, + { + Qname: "wild.d.example.org.", Qtype: dns.TypeTXT, + Answer: []dns.RR{ + test.TXT(`alias.example.org. IN TXT "Wildcard CNAME expansion"`), + test.CNAME(`wild.d.example.org. IN CNAME alias.example.org`), + }, + Ns: exampleAuth, + }, + { + Qname: "alias.example.org.", Qtype: dns.TypeTXT, + Answer: []dns.RR{ + test.TXT(`alias.example.org. IN TXT "Wildcard CNAME expansion"`), + }, + Ns: exampleAuth, + }, +} + +var exampleAuth = []dns.RR{ + test.NS("example.org. 3600 IN NS a.iana-servers.net."), + test.NS("example.org. 3600 IN NS b.iana-servers.net."), +} + +func TestLookupDoubleWildcard(t *testing.T) { + zone, err := Parse(strings.NewReader(exampleOrg), "example.org.", "stdin", 0) + if err != nil { + t.Fatalf("Expect no error when reading zone, got %q", err) + } + + fm := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{"example.org.": zone}, Names: []string{"example.org."}}} + ctx := context.TODO() + + for _, tc := range wildcardDoubleTestCases { + m := tc.Msg() + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err := fm.ServeDNS(ctx, rec, m) + if err != nil { + t.Errorf("Expected no error, got %v", err) + return + } + + resp := rec.Msg + if err := test.SortAndCheck(resp, tc); err != nil { + t.Error(err) + } + } +} + +func TestReplaceWithAsteriskLabel(t *testing.T) { + tests := []struct { + in, out string + }{ + {".", ""}, + {"miek.nl.", "*.nl."}, + {"www.miek.nl.", "*.miek.nl."}, + } + + for _, tc := range tests { + got := replaceWithAsteriskLabel(tc.in) + if got != tc.out { + t.Errorf("Expected to be %s, got %s", tc.out, got) + } + } +} + +var apexWildcardTestCases = []test.Case{ + { + Qname: "foo.example.org.", Qtype: dns.TypeA, + Answer: []dns.RR{test.A(`foo.example.org. 3600 IN A 127.0.0.54`)}, + Ns: []dns.RR{test.NS(`example.org. 
3600 IN NS b.iana-servers.net.`)}, + }, + { + Qname: "bar.example.org.", Qtype: dns.TypeA, + Answer: []dns.RR{test.A(`bar.example.org. 3600 IN A 127.0.0.53`)}, + Ns: []dns.RR{test.NS(`example.org. 3600 IN NS b.iana-servers.net.`)}, + }, +} + +func TestLookupApexWildcard(t *testing.T) { + const name = "example.org." + zone, err := Parse(strings.NewReader(apexWildcard), name, "stdin", 0) + if err != nil { + t.Fatalf("Expect no error when reading zone, got %q", err) + } + + fm := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{name: zone}, Names: []string{name}}} + ctx := context.TODO() + + for _, tc := range apexWildcardTestCases { + m := tc.Msg() + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err := fm.ServeDNS(ctx, rec, m) + if err != nil { + t.Errorf("Expected no error, got %v", err) + return + } + + resp := rec.Msg + if err := test.SortAndCheck(resp, tc); err != nil { + t.Error(err) + } + } +} + +var multiWildcardTestCases = []test.Case{ + { + Qname: "foo.example.org.", Qtype: dns.TypeA, + Answer: []dns.RR{test.A(`foo.example.org. 3600 IN A 127.0.0.54`)}, + Ns: []dns.RR{test.NS(`example.org. 3600 IN NS b.iana-servers.net.`)}, + }, + { + Qname: "bar.example.org.", Qtype: dns.TypeA, + Answer: []dns.RR{test.A(`bar.example.org. 3600 IN A 127.0.0.53`)}, + Ns: []dns.RR{test.NS(`example.org. 3600 IN NS b.iana-servers.net.`)}, + }, + { + Qname: "bar.intern.example.org.", Qtype: dns.TypeA, + Answer: []dns.RR{test.A(`bar.intern.example.org. 3600 IN A 127.0.1.52`)}, + Ns: []dns.RR{test.NS(`example.org. 3600 IN NS b.iana-servers.net.`)}, + }, +} + +func TestLookupMultiWildcard(t *testing.T) { + const name = "example.org." + zone, err := Parse(strings.NewReader(doubleWildcard), name, "stdin", 0) + if err != nil { + t.Fatalf("Expect no error when reading zone, got %q", err) + } + + fm := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{name: zone}, Names: []string{name}}} + ctx := context.TODO() + + for _, tc := range multiWildcardTestCases { + m := tc.Msg() + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err := fm.ServeDNS(ctx, rec, m) + if err != nil { + t.Errorf("Expected no error, got %v", err) + return + } + + resp := rec.Msg + if err := test.SortAndCheck(resp, tc); err != nil { + t.Error(err) + } + } +} + +const exampleOrg = `; example.org test file +$TTL 3600 +example.org. IN SOA sns.dns.icann.org. noc.dns.icann.org. 2015082541 7200 3600 1209600 3600 +example.org. IN NS b.iana-servers.net. +example.org. IN NS a.iana-servers.net. +example.org. IN A 127.0.0.1 +example.org. IN A 127.0.0.2 +*.w.example.org. IN TXT "Wildcard" +a.b.c.w.example.org. IN TXT "Not a wildcard" +*.c.example.org. IN TXT "c Wildcard" +*.d.example.org. IN CNAME alias.example.org. +alias.example.org. IN TXT "Wildcard CNAME expansion" +` + +const apexWildcard = `; example.org test file with wildcard at apex +$TTL 3600 +example.org. IN SOA sns.dns.icann.org. noc.dns.icann.org. 2015082541 7200 3600 1209600 3600 +example.org. IN NS b.iana-servers.net. +*.example.org. IN A 127.0.0.53 +foo.example.org. IN A 127.0.0.54 +` + +const doubleWildcard = `; example.org test file with wildcard at apex +$TTL 3600 +example.org. IN SOA sns.dns.icann.org. noc.dns.icann.org. 2015082541 7200 3600 1209600 3600 +example.org. IN NS b.iana-servers.net. +*.example.org. IN A 127.0.0.53 +*.intern.example.org. IN A 127.0.1.52 +foo.example.org. 
IN A 127.0.0.54 +` diff --git a/plugin/file/xfr.go b/plugin/file/xfr.go new file mode 100644 index 0000000..28c3a3a --- /dev/null +++ b/plugin/file/xfr.go @@ -0,0 +1,45 @@ +package file + +import ( + "github.com/coredns/coredns/plugin/file/tree" + "github.com/coredns/coredns/plugin/transfer" + + "github.com/miekg/dns" +) + +// Transfer implements the transfer.Transfer interface. +func (f File) Transfer(zone string, serial uint32) (<-chan []dns.RR, error) { + z, ok := f.Zones.Z[zone] + if !ok || z == nil { + return nil, transfer.ErrNotAuthoritative + } + return z.Transfer(serial) +} + +// Transfer transfers a zone with serial in the returned channel and implements IXFR fallback, by just +// sending a single SOA record. +func (z *Zone) Transfer(serial uint32) (<-chan []dns.RR, error) { + // get soa and apex + apex, err := z.ApexIfDefined() + if err != nil { + return nil, err + } + + ch := make(chan []dns.RR) + go func() { + if serial != 0 && apex[0].(*dns.SOA).Serial == serial { // ixfr fallback, only send SOA + ch <- []dns.RR{apex[0]} + + close(ch) + return + } + + ch <- apex + z.Walk(func(e *tree.Elem, _ map[uint16][]dns.RR) error { ch <- e.All(); return nil }) + ch <- []dns.RR{apex[0]} + + close(ch) + }() + + return ch, nil +} diff --git a/plugin/file/xfr_test.go b/plugin/file/xfr_test.go new file mode 100644 index 0000000..f8d4caf --- /dev/null +++ b/plugin/file/xfr_test.go @@ -0,0 +1,72 @@ +package file + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func ExampleZone_All() { + zone, err := Parse(strings.NewReader(dbMiekNL), testzone, "stdin", 0) + if err != nil { + return + } + records := zone.All() + for _, r := range records { + fmt.Printf("%+v\n", r) + } + // Output + // xfr_test.go:15: miek.nl. 1800 IN SOA linode.atoom.net. miek.miek.nl. 1282630057 14400 3600 604800 14400 + // xfr_test.go:15: www.miek.nl. 1800 IN CNAME a.miek.nl. + // xfr_test.go:15: miek.nl. 1800 IN NS linode.atoom.net. + // xfr_test.go:15: miek.nl. 1800 IN NS ns-ext.nlnetlabs.nl. + // xfr_test.go:15: miek.nl. 1800 IN NS omval.tednet.nl. + // xfr_test.go:15: miek.nl. 1800 IN NS ext.ns.whyscream.net. + // xfr_test.go:15: miek.nl. 1800 IN MX 1 aspmx.l.google.com. + // xfr_test.go:15: miek.nl. 1800 IN MX 5 alt1.aspmx.l.google.com. + // xfr_test.go:15: miek.nl. 1800 IN MX 5 alt2.aspmx.l.google.com. + // xfr_test.go:15: miek.nl. 1800 IN MX 10 aspmx2.googlemail.com. + // xfr_test.go:15: miek.nl. 1800 IN MX 10 aspmx3.googlemail.com. + // xfr_test.go:15: miek.nl. 1800 IN A 139.162.196.78 + // xfr_test.go:15: miek.nl. 1800 IN AAAA 2a01:7e00::f03c:91ff:fef1:6735 + // xfr_test.go:15: archive.miek.nl. 1800 IN CNAME a.miek.nl. + // xfr_test.go:15: a.miek.nl. 1800 IN A 139.162.196.78 + // xfr_test.go:15: a.miek.nl. 
1800 IN AAAA 2a01:7e00::f03c:91ff:fef1:6735 +} + +func TestAllNewZone(t *testing.T) { + zone := NewZone("example.org.", "stdin") + records := zone.All() + if len(records) != 0 { + t.Errorf("Expected %d records in empty zone, got %d", 0, len(records)) + } +} + +func TestAXFRWithOutTransferPlugin(t *testing.T) { + zone, err := Parse(strings.NewReader(dbMiekNL), testzone, "stdin", 0) + if err != nil { + t.Fatalf("Expected no error when reading zone, got %q", err) + } + + fm := File{Next: test.ErrorHandler(), Zones: Zones{Z: map[string]*Zone{testzone: zone}, Names: []string{testzone}}} + ctx := context.TODO() + + m := new(dns.Msg) + m.SetQuestion("miek.nl.", dns.TypeAXFR) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + code, err := fm.ServeDNS(ctx, rec, m) + if err != nil { + t.Errorf("Expected no error, got %v", err) + return + } + if code != dns.RcodeRefused { + t.Errorf("Expecting REFUSED, got %d", code) + } +} diff --git a/plugin/file/zone.go b/plugin/file/zone.go new file mode 100644 index 0000000..aa5f3ca --- /dev/null +++ b/plugin/file/zone.go @@ -0,0 +1,178 @@ +package file + +import ( + "fmt" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/coredns/coredns/plugin/file/tree" + "github.com/coredns/coredns/plugin/pkg/upstream" + + "github.com/miekg/dns" +) + +// Zone is a structure that contains all data related to a DNS zone. +type Zone struct { + origin string + origLen int + file string + *tree.Tree + Apex + Expired bool + + sync.RWMutex + + StartupOnce sync.Once + TransferFrom []string + + ReloadInterval time.Duration + reloadShutdown chan bool + + Upstream *upstream.Upstream // Upstream for looking up external names during the resolution process. +} + +// Apex contains the apex records of a zone: SOA, NS and their potential signatures. +type Apex struct { + SOA *dns.SOA + NS []dns.RR + SIGSOA []dns.RR + SIGNS []dns.RR +} + +// NewZone returns a new zone. +func NewZone(name, file string) *Zone { + return &Zone{ + origin: dns.Fqdn(name), + origLen: dns.CountLabel(dns.Fqdn(name)), + file: filepath.Clean(file), + Tree: &tree.Tree{}, + reloadShutdown: make(chan bool), + } +} + +// Copy copies a zone. +func (z *Zone) Copy() *Zone { + z1 := NewZone(z.origin, z.file) + z1.TransferFrom = z.TransferFrom + z1.Expired = z.Expired + + z1.Apex = z.Apex + return z1 +} + +// CopyWithoutApex copies zone z without the Apex records. +func (z *Zone) CopyWithoutApex() *Zone { + z1 := NewZone(z.origin, z.file) + z1.TransferFrom = z.TransferFrom + z1.Expired = z.Expired + + return z1 +} + +// Insert inserts r into z. 
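
A short, self-contained sketch of how a zone is typically built through this method: SOA and apex NS records are routed into the Apex, everything else goes into the tree. All names and record data below are made up.

~~~ go
package main

import (
	"fmt"

	"github.com/coredns/coredns/plugin/file"
	"github.com/miekg/dns"
)

func main() {
	z := file.NewZone("example.org.", "db.example.org") // file name is only informational here
	for _, s := range []string{
		"example.org. 3600 IN SOA ns1.example.org. admin.example.org. 1 7200 3600 1209600 3600",
		"example.org. 3600 IN NS ns1.example.org.",
		"www.example.org. 3600 IN A 192.0.2.1",
	} {
		rr, err := dns.NewRR(s)
		if err != nil {
			panic(err)
		}
		if err := z.Insert(rr); err != nil {
			panic(err)
		}
	}

	fmt.Println(z.Apex.SOA.Serial) // 1: the SOA landed in the apex
	fmt.Println(len(z.Apex.NS))    // 1: so did the apex NS record
	_, found := z.Search("www.example.org.")
	fmt.Println(found) // true: the A record went into the tree
}
~~~
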
+func (z *Zone) Insert(r dns.RR) error { + r.Header().Name = strings.ToLower(r.Header().Name) + + switch h := r.Header().Rrtype; h { + case dns.TypeNS: + r.(*dns.NS).Ns = strings.ToLower(r.(*dns.NS).Ns) + + if r.Header().Name == z.origin { + z.Apex.NS = append(z.Apex.NS, r) + return nil + } + case dns.TypeSOA: + r.(*dns.SOA).Ns = strings.ToLower(r.(*dns.SOA).Ns) + r.(*dns.SOA).Mbox = strings.ToLower(r.(*dns.SOA).Mbox) + + z.Apex.SOA = r.(*dns.SOA) + return nil + case dns.TypeNSEC3, dns.TypeNSEC3PARAM: + return fmt.Errorf("NSEC3 zone is not supported, dropping RR: %s for zone: %s", r.Header().Name, z.origin) + case dns.TypeRRSIG: + x := r.(*dns.RRSIG) + switch x.TypeCovered { + case dns.TypeSOA: + z.Apex.SIGSOA = append(z.Apex.SIGSOA, x) + return nil + case dns.TypeNS: + if r.Header().Name == z.origin { + z.Apex.SIGNS = append(z.Apex.SIGNS, x) + return nil + } + } + case dns.TypeCNAME: + r.(*dns.CNAME).Target = strings.ToLower(r.(*dns.CNAME).Target) + case dns.TypeMX: + r.(*dns.MX).Mx = strings.ToLower(r.(*dns.MX).Mx) + case dns.TypeSRV: + r.(*dns.SRV).Target = strings.ToLower(r.(*dns.SRV).Target) + } + + z.Tree.Insert(r) + return nil +} + +// File retrieves the file path in a safe way. +func (z *Zone) File() string { + z.RLock() + defer z.RUnlock() + return z.file +} + +// SetFile updates the file path in a safe way. +func (z *Zone) SetFile(path string) { + z.Lock() + z.file = path + z.Unlock() +} + +// ApexIfDefined returns the apex nodes from z. The SOA record is the first record, if it does not exist, an error is returned. +func (z *Zone) ApexIfDefined() ([]dns.RR, error) { + z.RLock() + defer z.RUnlock() + if z.Apex.SOA == nil { + return nil, fmt.Errorf("no SOA") + } + + rrs := []dns.RR{z.Apex.SOA} + + if len(z.Apex.SIGSOA) > 0 { + rrs = append(rrs, z.Apex.SIGSOA...) + } + if len(z.Apex.NS) > 0 { + rrs = append(rrs, z.Apex.NS...) + } + if len(z.Apex.SIGNS) > 0 { + rrs = append(rrs, z.Apex.SIGNS...) + } + + return rrs, nil +} + +// NameFromRight returns the labels from the right, staring with the +// origin and then i labels extra. When we are overshooting the name +// the returned boolean is set to true. 
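
Because nameFromRight is unexported, the illustration below is written as a hypothetical package-internal snippet (not part of this commit); it only shows the return values, including the overshoot case, for a zone whose origin is example.org. (two labels).

~~~ go
package file

import "fmt"

// nameFromRightSketch is a hypothetical helper used only to demonstrate the
// behaviour of nameFromRight for an origin of "example.org.".
func nameFromRightSketch() {
	z := NewZone("example.org.", "stdin")
	fmt.Println(z.nameFromRight("a.b.example.org.", 0)) // example.org. false
	fmt.Println(z.nameFromRight("a.b.example.org.", 1)) // b.example.org. false
	fmt.Println(z.nameFromRight("a.b.example.org.", 2)) // a.b.example.org. false
	fmt.Println(z.nameFromRight("a.b.example.org.", 3)) // a.b.example.org. true (overshot the name)
}
~~~
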
+func (z *Zone) nameFromRight(qname string, i int) (string, bool) { + if i <= 0 { + return z.origin, false + } + + for j := 1; j <= z.origLen; j++ { + if _, shot := dns.PrevLabel(qname, j); shot { + return qname, shot + } + } + + k := 0 + var shot bool + for j := 1; j <= i; j++ { + k, shot = dns.PrevLabel(qname, j+z.origLen) + if shot { + return qname, shot + } + } + return qname[k:], false +} diff --git a/plugin/file/zone_test.go b/plugin/file/zone_test.go new file mode 100644 index 0000000..aa42fd8 --- /dev/null +++ b/plugin/file/zone_test.go @@ -0,0 +1,30 @@ +package file + +import "testing" + +func TestNameFromRight(t *testing.T) { + z := NewZone("example.org.", "stdin") + + tests := []struct { + in string + labels int + shot bool + expected string + }{ + {"example.org.", 0, false, "example.org."}, + {"a.example.org.", 0, false, "example.org."}, + {"a.example.org.", 1, false, "a.example.org."}, + {"a.example.org.", 2, true, "a.example.org."}, + {"a.b.example.org.", 2, false, "a.b.example.org."}, + } + + for i, tc := range tests { + got, shot := z.nameFromRight(tc.in, tc.labels) + if got != tc.expected { + t.Errorf("Test %d: expected %s, got %s", i, tc.expected, got) + } + if shot != tc.shot { + t.Errorf("Test %d: expected shot to be %t, got %t", i, tc.shot, shot) + } + } +} diff --git a/plugin/forward/README.md b/plugin/forward/README.md new file mode 100644 index 0000000..7dd66f7 --- /dev/null +++ b/plugin/forward/README.md @@ -0,0 +1,273 @@ +# forward + +## Name + +*forward* - facilitates proxying DNS messages to upstream resolvers. + +## Description + +The *forward* plugin re-uses already opened sockets to the upstreams. It supports UDP, TCP and +DNS-over-TLS and uses in band health checking. + +When it detects an error a health check is performed. This checks runs in a loop, performing each +check at a *0.5s* interval for as long as the upstream reports unhealthy. Once healthy we stop +health checking (until the next error). The health checks use a recursive DNS query (`. IN NS`) +to get upstream health. Any response that is not a network error (REFUSED, NOTIMPL, SERVFAIL, etc) +is taken as a healthy upstream. The health check uses the same protocol as specified in **TO**. If +`max_fails` is set to 0, no checking is performed and upstreams will always be considered healthy. + +When *all* upstreams are down it assumes health checking as a mechanism has failed and will try to +connect to a random upstream (which may or may not work). + +## Syntax + +In its most basic form, a simple forwarder uses this syntax: + +~~~ +forward FROM TO... +~~~ + +* **FROM** is the base domain to match for the request to be forwarded. Domains using CIDR notation + that expand to multiple reverse zones are not fully supported; only the first expanded zone is used. +* **TO...** are the destination endpoints to forward to. The **TO** syntax allows you to specify + a protocol, `tls://9.9.9.9` or `dns://` (or no protocol) for plain DNS. The number of upstreams is + limited to 15. + +Multiple upstreams are randomized (see `policy`) on first use. When a healthy proxy returns an error +during the exchange the next upstream in the list is tried. + +Extra knobs are available with an expanded syntax: + +~~~ +forward FROM TO... { + except IGNORED_NAMES... + force_tcp + prefer_udp + expire DURATION + max_fails INTEGER + tls CERT KEY CA + tls_servername NAME + policy random|round_robin|sequential + health_check DURATION [no_rec] [domain FQDN] + max_concurrent MAX +} +~~~ + +* **FROM** and **TO...** as above. 
+* **IGNORED_NAMES** in `except` is a space-separated list of domains to exclude from forwarding.
+  Requests that match none of these names will be passed through.
+* `force_tcp`, use TCP even when the request comes in over UDP.
+* `prefer_udp`, try first using UDP even when the request comes in over TCP. If the response is
+  truncated (TC flag set in the response), another attempt is made over TCP. If both `force_tcp` and
+  `prefer_udp` are specified, `force_tcp` takes precedence.
+* `max_fails` is the number of subsequent failed health checks that are needed before considering
+  an upstream to be down. If 0, the upstream will never be marked as down (nor health checked).
+  Default is 2.
+* `expire` **DURATION**, expire (cached) connections after this time; the default is 10s.
+* `tls` **CERT** **KEY** **CA** defines the TLS properties for the TLS connection. From 0 to 3 arguments can be
+  provided, with the meanings described below:
+
+  * `tls` - no client authentication is used, and the system CAs are used to verify the server certificate
+  * `tls` **CA** - no client authentication is used, and the file CA is used to verify the server certificate
+  * `tls` **CERT** **KEY** - client authentication is used with the specified cert/key pair.
+    The server certificate is verified with the system CAs
+  * `tls` **CERT** **KEY** **CA** - client authentication is used with the specified cert/key pair.
+    The server certificate is verified using the specified CA file
+
+* `tls_servername` **NAME** allows you to set a server name in the TLS configuration; for instance 9.9.9.9
+  needs this to be set to `dns.quad9.net`. Multiple upstreams are still allowed in this scenario,
+  but they have to use the same `tls_servername`. E.g. mixing 9.9.9.9 (Quad9) with 1.1.1.1
+  (Cloudflare) will not work. Using TLS forwarding without setting `tls_servername` allows anyone
+  to man-in-the-middle your connection to the DNS server you are forwarding to. Because of this,
+  it is strongly recommended to set this value when using TLS forwarding.
+* `policy` specifies the policy to use for selecting upstream servers. The default is `random`.
+  * `random` is a policy that implements random upstream selection.
+  * `round_robin` is a policy that selects hosts based on round robin ordering.
+  * `sequential` is a policy that selects hosts based on sequential ordering.
+* `health_check` configures the behaviour of health checking of the upstream servers:
+  * `<duration>` - use a different duration for health checking; the default duration is 0.5s.
+  * `no_rec` - optional argument that sets the RecursionDesired flag of the DNS query used in health checking to `false`.
+    The flag defaults to `true`.
+  * `domain FQDN` - set the domain name used for health checks to **FQDN**.
+    If not configured, the domain name used for health checks is `.`.
+* `max_concurrent` **MAX** will limit the number of concurrent queries to **MAX**. Any new query that would
+  raise the number of concurrent queries above **MAX** will result in a REFUSED response. This
+  response does not count as a health failure. When choosing a value for **MAX**, pick a number
+  at least greater than the expected *upstream query rate* multiplied by the *latency* of the upstream servers.
+  As an upper bound for **MAX**, consider that each concurrent query will use about 2 kB of memory.
+
+Also note that the TLS config is "global" for the whole forwarding proxy: if you need a different
+`tls_servername` for different upstreams, you're out of luck.
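+
+For illustration, a stanza that combines several of these options might look like the following
+(the addresses, domains, and values are placeholders, not recommendations):
+
+~~~ corefile
+. {
+    forward example.com 10.0.0.10:53 10.0.0.11:53 {
+        except internal.example.com
+        max_fails 3
+        policy round_robin
+        health_check 2s domain example.com
+        max_concurrent 1000
+    }
+}
+~~~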
+
+On each endpoint, the timeouts for communication are set as follows:
+
+* The dial timeout by default is 30s, and can decrease automatically down to 1s based on early results.
+* The read timeout is static at 2s.
+
+## Metadata
+
+The forward plugin will publish the following metadata, if the *metadata*
+plugin is also enabled:
+
+* `forward/upstream`: the upstream used to forward the request
+
+## Metrics
+
+If monitoring is enabled (via the *prometheus* plugin) then the following metrics are exported:
+
+* `coredns_forward_healthcheck_broken_total{}` - count of when all upstreams are unhealthy,
+  and we are randomly (this always uses the `random` policy) spraying to an upstream.
+* `coredns_forward_max_concurrent_rejects_total{}` - count of queries rejected because the
+  number of concurrent queries was at maximum.
+* `coredns_proxy_request_duration_seconds{proxy_name="forward", to, rcode}` - histogram per upstream, RCODE
+* `coredns_proxy_healthcheck_failures_total{proxy_name="forward", to, rcode}` - count of failed health checks per upstream.
+* `coredns_proxy_conn_cache_hits_total{proxy_name="forward", to, proto}` - count of connection cache hits per upstream and protocol.
+* `coredns_proxy_conn_cache_misses_total{proxy_name="forward", to, proto}` - count of connection cache misses per upstream and protocol.
+
+Where `to` is one of the upstream servers (**TO** from the config), `rcode` is the RCODE returned
+by the upstream, and `proto` is the transport protocol, like `udp`, `tcp`, or `tcp-tls`.
+
+The following metrics have recently been deprecated:
+* `coredns_forward_healthcheck_failures_total{to, rcode}`
+    * Can be replaced with `coredns_proxy_healthcheck_failures_total{proxy_name="forward", to, rcode}`
+* `coredns_forward_requests_total{to}`
+    * Can be replaced with `sum(coredns_proxy_request_duration_seconds_count{proxy_name="forward", to})`
+* `coredns_forward_responses_total{to, rcode}`
+    * Can be replaced with `coredns_proxy_request_duration_seconds_count{proxy_name="forward", to, rcode}`
+* `coredns_forward_request_duration_seconds{to, rcode}`
+    * Can be replaced with `coredns_proxy_request_duration_seconds{proxy_name="forward", to, rcode}`
+
+## Examples
+
+Proxy all requests within `example.org.` to a nameserver running on a different port:
+
+~~~ corefile
+example.org {
+    forward . 127.0.0.1:9005
+}
+~~~
+
+Send all requests within `lab.example.local.` to `10.20.0.1`, all requests within `example.local.` (and not in
+`lab.example.local.`) to `10.0.0.1`, all other requests to the servers defined in `/etc/resolv.conf`, and
+cache the results. Note that a CoreDNS server configured with multiple _forward_ plugins in a server block will evaluate those
+forward plugins in the order they are listed when serving a request. Therefore, subdomains should be
+placed before parent domains; otherwise, subdomain requests will be forwarded to the parent domain's upstream.
+Accordingly, in this example `lab.example.local` is before `example.local`, and `example.local` is before `.`.
+
+~~~ corefile
+. {
+    cache
+    forward lab.example.local 10.20.0.1
+    forward example.local 10.0.0.1
+    forward . /etc/resolv.conf
+}
+~~~
+
+The example above is almost equivalent to the following example, except that the example below defines three separate plugin
+chains (and thus three separate instances of _cache_).
+
+~~~ corefile
+lab.example.local {
+    cache
+    forward . 10.20.0.1
+}
+example.local {
+    cache
+    forward . 10.0.0.1
+}
+. {
+    cache
+    forward . 
/etc/resolv.conf +} +~~~ + +Load balance all requests between three resolvers, one of which has a IPv6 address. + +~~~ corefile +. { + forward . 10.0.0.10:53 10.0.0.11:1053 [2003::1]:53 +} +~~~ + +Forward everything except requests to `example.org` + +~~~ corefile +. { + forward . 10.0.0.10:1234 { + except example.org + } +} +~~~ + +Proxy everything except `example.org` using the host's `resolv.conf`'s nameservers: + +~~~ corefile +. { + forward . /etc/resolv.conf { + except example.org + } +} +~~~ + +Proxy all requests to 9.9.9.9 using the DNS-over-TLS (DoT) protocol, and cache every answer for up to 30 +seconds. Note the `tls_servername` is mandatory if you want a working setup, as 9.9.9.9 can't be +used in the TLS negotiation. Also set the health check duration to 5s to not completely swamp the +service with health checks. + +~~~ corefile +. { + forward . tls://9.9.9.9 { + tls_servername dns.quad9.net + health_check 5s + } + cache 30 +} +~~~ + +Or configure other domain name for health check requests + +~~~ corefile +. { + forward . tls://9.9.9.9 { + tls_servername dns.quad9.net + health_check 5s domain example.org + } + cache 30 +} +~~~ + +Or with multiple upstreams from the same provider + +~~~ corefile +. { + forward . tls://1.1.1.1 tls://1.0.0.1 { + tls_servername cloudflare-dns.com + health_check 5s + } + cache 30 +} +~~~ + +Or when you have multiple DoT upstreams with different `tls_servername`s, you can do the following: + +~~~ corefile +. { + forward . 127.0.0.1:5301 127.0.0.1:5302 +} + +.:5301 { + forward . tls://8.8.8.8 tls://8.8.4.4 { + tls_servername dns.google + } +} + +.:5302 { + forward . tls://1.1.1.1 tls://1.0.0.1 { + tls_servername cloudflare-dns.com + } +} +~~~ + +## See Also + +[RFC 7858](https://tools.ietf.org/html/rfc7858) for DNS over TLS. diff --git a/plugin/forward/dnstap.go b/plugin/forward/dnstap.go new file mode 100644 index 0000000..8195bb4 --- /dev/null +++ b/plugin/forward/dnstap.go @@ -0,0 +1,66 @@ +package forward + +import ( + "context" + "net" + "strconv" + "time" + + "github.com/coredns/coredns/plugin/dnstap/msg" + "github.com/coredns/coredns/plugin/pkg/proxy" + "github.com/coredns/coredns/request" + + tap "github.com/dnstap/golang-dnstap" + "github.com/miekg/dns" +) + +// toDnstap will send the forward and received message to the dnstap plugin. 
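+// It builds FORWARDER_QUERY and, when a reply was received, FORWARDER_RESPONSE messages, using the
+// downstream client's address as the query address and the selected upstream as the response
+// address, and hands them to every configured dnstap plugin.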
+func toDnstap(ctx context.Context, f *Forward, host string, state request.Request, opts proxy.Options, reply *dns.Msg, start time.Time) { + h, p, _ := net.SplitHostPort(host) // this is preparsed and can't err here + port, _ := strconv.ParseUint(p, 10, 32) // same here + ip := net.ParseIP(h) + + var ta net.Addr = &net.UDPAddr{IP: ip, Port: int(port)} + t := state.Proto() + switch { + case opts.ForceTCP: + t = "tcp" + case opts.PreferUDP: + t = "udp" + } + + if t == "tcp" { + ta = &net.TCPAddr{IP: ip, Port: int(port)} + } + + for _, t := range f.tapPlugins { + // Query + q := new(tap.Message) + msg.SetQueryTime(q, start) + // Forwarder dnstap messages are from the perspective of the downstream server + // (upstream is the forward server) + msg.SetQueryAddress(q, state.W.RemoteAddr()) + msg.SetResponseAddress(q, ta) + if t.IncludeRawMessage { + buf, _ := state.Req.Pack() + q.QueryMessage = buf + } + msg.SetType(q, tap.Message_FORWARDER_QUERY) + t.TapMessageWithMetadata(ctx, q, state) + + // Response + if reply != nil { + r := new(tap.Message) + if t.IncludeRawMessage { + buf, _ := reply.Pack() + r.ResponseMessage = buf + } + msg.SetQueryTime(r, start) + msg.SetQueryAddress(r, state.W.RemoteAddr()) + msg.SetResponseAddress(r, ta) + msg.SetResponseTime(r, time.Now()) + msg.SetType(r, tap.Message_FORWARDER_RESPONSE) + t.TapMessageWithMetadata(ctx, r, state) + } + } +} diff --git a/plugin/forward/forward.go b/plugin/forward/forward.go new file mode 100644 index 0000000..e53d74a --- /dev/null +++ b/plugin/forward/forward.go @@ -0,0 +1,259 @@ +// Package forward implements a forwarding proxy. It caches an upstream net.Conn for some time, so if the same +// client returns the upstream's Conn will be precached. Depending on how you benchmark this looks to be +// 50% faster than just opening a new connection for every client. It works with UDP and TCP and uses +// inband healthchecking. +package forward + +import ( + "context" + "crypto/tls" + "errors" + "sync/atomic" + "time" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/debug" + "github.com/coredns/coredns/plugin/dnstap" + "github.com/coredns/coredns/plugin/metadata" + clog "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/plugin/pkg/proxy" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" + ot "github.com/opentracing/opentracing-go" + otext "github.com/opentracing/opentracing-go/ext" +) + +var log = clog.NewWithPlugin("forward") + +const ( + defaultExpire = 10 * time.Second + hcInterval = 500 * time.Millisecond +) + +// Forward represents a plugin instance that can proxy requests to another (DNS) server. It has a list +// of proxies each representing one upstream proxy. +type Forward struct { + concurrent int64 // atomic counters need to be first in struct for proper alignment + + proxies []*proxy.Proxy + p Policy + hcInterval time.Duration + + from string + ignored []string + + tlsConfig *tls.Config + tlsServerName string + maxfails uint32 + expire time.Duration + maxConcurrent int64 + + opts proxy.Options // also here for testing + + // ErrLimitExceeded indicates that a query was rejected because the number of concurrent queries has exceeded + // the maximum allowed (maxConcurrent) + ErrLimitExceeded error + + tapPlugins []*dnstap.Dnstap // when dnstap plugins are loaded, we use to this to send messages out. + + Next plugin.Handler +} + +// New returns a new Forward. 
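+// The returned instance defaults to the random policy, a max_fails of 2, a 10s connection expiry,
+// a 500ms health-check interval, and recursive health checks against the root zone.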
+func New() *Forward { + f := &Forward{maxfails: 2, tlsConfig: new(tls.Config), expire: defaultExpire, p: new(random), from: ".", hcInterval: hcInterval, opts: proxy.Options{ForceTCP: false, PreferUDP: false, HCRecursionDesired: true, HCDomain: "."}} + return f +} + +// SetProxy appends p to the proxy list and starts healthchecking. +func (f *Forward) SetProxy(p *proxy.Proxy) { + f.proxies = append(f.proxies, p) + p.Start(f.hcInterval) +} + +// SetTapPlugin appends one or more dnstap plugins to the tap plugin list. +func (f *Forward) SetTapPlugin(tapPlugin *dnstap.Dnstap) { + f.tapPlugins = append(f.tapPlugins, tapPlugin) + if nextPlugin, ok := tapPlugin.Next.(*dnstap.Dnstap); ok { + f.SetTapPlugin(nextPlugin) + } +} + +// Len returns the number of configured proxies. +func (f *Forward) Len() int { return len(f.proxies) } + +// Name implements plugin.Handler. +func (f *Forward) Name() string { return "forward" } + +// ServeDNS implements plugin.Handler. +func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + state := request.Request{W: w, Req: r} + if !f.match(state) { + return plugin.NextOrFailure(f.Name(), f.Next, ctx, w, r) + } + + if f.maxConcurrent > 0 { + count := atomic.AddInt64(&(f.concurrent), 1) + defer atomic.AddInt64(&(f.concurrent), -1) + if count > f.maxConcurrent { + maxConcurrentRejectCount.Add(1) + return dns.RcodeRefused, f.ErrLimitExceeded + } + } + + fails := 0 + var span, child ot.Span + var upstreamErr error + span = ot.SpanFromContext(ctx) + i := 0 + list := f.List() + deadline := time.Now().Add(defaultTimeout) + start := time.Now() + for time.Now().Before(deadline) { + if i >= len(list) { + // reached the end of list, reset to begin + i = 0 + fails = 0 + } + + proxy := list[i] + i++ + if proxy.Down(f.maxfails) { + fails++ + if fails < len(f.proxies) { + continue + } + // All upstream proxies are dead, assume healthcheck is completely broken and randomly + // select an upstream to connect to. + r := new(random) + proxy = r.List(f.proxies)[0] + + healthcheckBrokenCount.Add(1) + } + + if span != nil { + child = span.Tracer().StartSpan("connect", ot.ChildOf(span.Context())) + otext.PeerAddress.Set(child, proxy.Addr()) + ctx = ot.ContextWithSpan(ctx, child) + } + + metadata.SetValueFunc(ctx, "forward/upstream", func() string { + return proxy.Addr() + }) + + var ( + ret *dns.Msg + err error + ) + opts := f.opts + + for { + ret, err = proxy.Connect(ctx, state, opts) + + if err == ErrCachedClosed { // Remote side closed conn, can only happen with TCP. + continue + } + // Retry with TCP if truncated and prefer_udp configured. + if ret != nil && ret.Truncated && !opts.ForceTCP && opts.PreferUDP { + opts.ForceTCP = true + continue + } + break + } + + if child != nil { + child.Finish() + } + + if len(f.tapPlugins) != 0 { + toDnstap(ctx, f, proxy.Addr(), state, opts, ret, start) + } + + upstreamErr = err + + if err != nil { + // Kick off health check to see if *our* upstream is broken. + if f.maxfails != 0 { + proxy.Healthcheck() + } + + if fails < len(f.proxies) { + continue + } + break + } + + // Check if the reply is correct; if not return FormErr. 
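+		// A reply whose ID or question section does not match the request is discarded and a
+		// FORMERR is returned to the client instead of a potentially bogus answer.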
+ if !state.Match(ret) { + debug.Hexdumpf(ret, "Wrong reply for id: %d, %s %d", ret.Id, state.QName(), state.QType()) + + formerr := new(dns.Msg) + formerr.SetRcode(state.Req, dns.RcodeFormatError) + w.WriteMsg(formerr) + return 0, nil + } + + w.WriteMsg(ret) + return 0, nil + } + + if upstreamErr != nil { + return dns.RcodeServerFailure, upstreamErr + } + + return dns.RcodeServerFailure, ErrNoHealthy +} + +func (f *Forward) match(state request.Request) bool { + if !plugin.Name(f.from).Matches(state.Name()) || !f.isAllowedDomain(state.Name()) { + return false + } + + return true +} + +func (f *Forward) isAllowedDomain(name string) bool { + if dns.Name(name) == dns.Name(f.from) { + return true + } + + for _, ignore := range f.ignored { + if plugin.Name(ignore).Matches(name) { + return false + } + } + return true +} + +// ForceTCP returns if TCP is forced to be used even when the request comes in over UDP. +func (f *Forward) ForceTCP() bool { return f.opts.ForceTCP } + +// PreferUDP returns if UDP is preferred to be used even when the request comes in over TCP. +func (f *Forward) PreferUDP() bool { return f.opts.PreferUDP } + +// List returns a set of proxies to be used for this client depending on the policy in f. +func (f *Forward) List() []*proxy.Proxy { return f.p.List(f.proxies) } + +var ( + // ErrNoHealthy means no healthy proxies left. + ErrNoHealthy = errors.New("no healthy proxies") + // ErrNoForward means no forwarder defined. + ErrNoForward = errors.New("no forwarder defined") + // ErrCachedClosed means cached connection was closed by peer. + ErrCachedClosed = errors.New("cached connection was closed by peer") +) + +// Options holds various Options that can be set. +type Options struct { + // ForceTCP use TCP protocol for upstream DNS request. Has precedence over PreferUDP flag + ForceTCP bool + // PreferUDP use UDP protocol for upstream DNS request. + PreferUDP bool + // HCRecursionDesired sets recursion desired flag for Proxy healthcheck requests + HCRecursionDesired bool + // HCDomain sets domain for Proxy healthcheck requests + HCDomain string +} + +var defaultTimeout = 5 * time.Second diff --git a/plugin/forward/forward_test.go b/plugin/forward/forward_test.go new file mode 100644 index 0000000..aca58cb --- /dev/null +++ b/plugin/forward/forward_test.go @@ -0,0 +1,76 @@ +package forward + +import ( + "strings" + "testing" + + "github.com/coredns/caddy" + "github.com/coredns/caddy/caddyfile" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin/dnstap" + "github.com/coredns/coredns/plugin/pkg/proxy" + "github.com/coredns/coredns/plugin/pkg/transport" +) + +func TestList(t *testing.T) { + f := Forward{ + proxies: []*proxy.Proxy{ + proxy.NewProxy("TestList", "1.1.1.1:53", transport.DNS), + proxy.NewProxy("TestList", "2.2.2.2:53", transport.DNS), + proxy.NewProxy("TestList", "3.3.3.3:53", transport.DNS), + }, + p: &roundRobin{}, + } + + expect := []*proxy.Proxy{ + proxy.NewProxy("TestList", "2.2.2.2:53", transport.DNS), + proxy.NewProxy("TestList", "1.1.1.1:53", transport.DNS), + proxy.NewProxy("TestList", "3.3.3.3:53", transport.DNS), + } + got := f.List() + + if len(got) != len(expect) { + t.Fatalf("Expected: %v results, got: %v", len(expect), len(got)) + } + for i, p := range got { + if p.Addr() != expect[i].Addr() { + t.Fatalf("Expected proxy %v to be '%v', got: '%v'", i, expect[i].Addr(), p.Addr()) + } + } +} + +func TestSetTapPlugin(t *testing.T) { + input := `forward . 
127.0.0.1 + dnstap /tmp/dnstap.sock full + dnstap tcp://example.com:6000 + ` + stanzas := strings.Split(input, "\n") + c := caddy.NewTestController("dns", strings.Join(stanzas[1:], "\n")) + dnstapSetup, err := caddy.DirectiveAction("dns", "dnstap") + if err != nil { + t.Fatal(err) + } + if err = dnstapSetup(c); err != nil { + t.Fatal(err) + } + c.Dispenser = caddyfile.NewDispenser("", strings.NewReader(stanzas[0])) + if err = setup(c); err != nil { + t.Fatal(err) + } + dnsserver.NewServer("", []*dnsserver.Config{dnsserver.GetConfig(c)}) + f, ok := dnsserver.GetConfig(c).Handler("forward").(*Forward) + if !ok { + t.Fatal("Expected a forward plugin") + } + tap, ok := dnsserver.GetConfig(c).Handler("dnstap").(*dnstap.Dnstap) + if !ok { + t.Fatal("Expected a dnstap plugin") + } + f.SetTapPlugin(tap) + if len(f.tapPlugins) != 2 { + t.Fatalf("Expected: 2 results, got: %v", len(f.tapPlugins)) + } + if f.tapPlugins[0] != tap || tap.Next != f.tapPlugins[1] { + t.Error("Unexpected order of dnstap plugins") + } +} diff --git a/plugin/forward/fuzz.go b/plugin/forward/fuzz.go new file mode 100644 index 0000000..ba6e915 --- /dev/null +++ b/plugin/forward/fuzz.go @@ -0,0 +1,35 @@ +//go:build gofuzz + +package forward + +import ( + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/pkg/fuzz" + "github.com/coredns/coredns/plugin/pkg/proxy" + + "github.com/miekg/dns" +) + +var f *Forward + +// abuse init to setup an environment to test against. This start another server to that will +// reflect responses. +func init() { + f = New() + s := dnstest.NewServer(r{}.reflectHandler) + f.SetProxy(proxy.NewProxy(s.Addr, "tcp")) + f.SetProxy(proxy.NewProxy(s.Addr, "udp")) +} + +// Fuzz fuzzes forward. +func Fuzz(data []byte) int { + return fuzz.Do(f, data) +} + +type r struct{} + +func (r r) reflectHandler(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.SetReply(req) + w.WriteMsg(m) +} diff --git a/plugin/forward/health_test.go b/plugin/forward/health_test.go new file mode 100644 index 0000000..211a620 --- /dev/null +++ b/plugin/forward/health_test.go @@ -0,0 +1,279 @@ +package forward + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/pkg/proxy" + "github.com/coredns/coredns/plugin/pkg/transport" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestHealth(t *testing.T) { + defaultTimeout = 10 * time.Millisecond + + i := uint32(0) + q := uint32(0) + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + if atomic.LoadUint32(&q) == 0 { //drop the first query to trigger health-checking + atomic.AddUint32(&q, 1) + return + } + if r.Question[0].Name == "." 
&& r.RecursionDesired == true { + atomic.AddUint32(&i, 1) + } + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + p := proxy.NewProxy("TestHealth", s.Addr, transport.DNS) + p.GetHealthchecker().SetReadTimeout(10 * time.Millisecond) + p.GetHealthchecker().SetWriteTimeout(10 * time.Millisecond) + f := New() + f.SetProxy(p) + defer f.OnShutdown() + + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + + f.ServeDNS(context.TODO(), &test.ResponseWriter{}, req) + + time.Sleep(20 * time.Millisecond) + i1 := atomic.LoadUint32(&i) + if i1 != 1 { + t.Errorf("Expected number of health checks with RecursionDesired==true to be %d, got %d", 1, i1) + } +} + +func TestHealthTCP(t *testing.T) { + defaultTimeout = 10 * time.Millisecond + + i := uint32(0) + q := uint32(0) + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + if atomic.LoadUint32(&q) == 0 { //drop the first query to trigger health-checking + atomic.AddUint32(&q, 1) + return + } + if r.Question[0].Name == "." && r.RecursionDesired == true { + atomic.AddUint32(&i, 1) + } + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + p := proxy.NewProxy("TestHealthTCP", s.Addr, transport.DNS) + p.GetHealthchecker().SetReadTimeout(10 * time.Millisecond) + p.GetHealthchecker().SetWriteTimeout(10 * time.Millisecond) + p.GetHealthchecker().SetTCPTransport() + f := New() + f.SetProxy(p) + defer f.OnShutdown() + + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + + f.ServeDNS(context.TODO(), &test.ResponseWriter{TCP: true}, req) + + time.Sleep(20 * time.Millisecond) + i1 := atomic.LoadUint32(&i) + if i1 != 1 { + t.Errorf("Expected number of health checks with RecursionDesired==true to be %d, got %d", 1, i1) + } +} + +func TestHealthNoRecursion(t *testing.T) { + defaultTimeout = 10 * time.Millisecond + + i := uint32(0) + q := uint32(0) + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + if atomic.LoadUint32(&q) == 0 { //drop the first query to trigger health-checking + atomic.AddUint32(&q, 1) + return + } + if r.Question[0].Name == "." && r.RecursionDesired == false { + atomic.AddUint32(&i, 1) + } + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + p := proxy.NewProxy("TestHealthNoRecursion", s.Addr, transport.DNS) + p.GetHealthchecker().SetReadTimeout(10 * time.Millisecond) + p.GetHealthchecker().SetWriteTimeout(10 * time.Millisecond) + p.GetHealthchecker().SetRecursionDesired(false) + f := New() + f.SetProxy(p) + defer f.OnShutdown() + + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + + f.ServeDNS(context.TODO(), &test.ResponseWriter{}, req) + + time.Sleep(20 * time.Millisecond) + i1 := atomic.LoadUint32(&i) + if i1 != 1 { + t.Errorf("Expected number of health checks with RecursionDesired==false to be %d, got %d", 1, i1) + } +} + +func TestHealthTimeout(t *testing.T) { + defaultTimeout = 10 * time.Millisecond + + i := uint32(0) + q := uint32(0) + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + if r.Question[0].Name == "." 
{ + // health check, answer + atomic.AddUint32(&i, 1) + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + return + } + if atomic.LoadUint32(&q) == 0 { //drop only first query + atomic.AddUint32(&q, 1) + return + } + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + p := proxy.NewProxy("TestHealthTimeout", s.Addr, transport.DNS) + p.GetHealthchecker().SetReadTimeout(10 * time.Millisecond) + p.GetHealthchecker().SetWriteTimeout(10 * time.Millisecond) + f := New() + f.SetProxy(p) + defer f.OnShutdown() + + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + + f.ServeDNS(context.TODO(), &test.ResponseWriter{}, req) + + time.Sleep(20 * time.Millisecond) + i1 := atomic.LoadUint32(&i) + if i1 != 1 { + t.Errorf("Expected number of health checks to be %d, got %d", 1, i1) + } +} + +func TestHealthMaxFails(t *testing.T) { + defaultTimeout = 10 * time.Millisecond + //,hcInterval = 10 * time.Millisecond + + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + // timeout + }) + defer s.Close() + + p := proxy.NewProxy("TestHealthMaxFails", s.Addr, transport.DNS) + p.SetReadTimeout(10 * time.Millisecond) + p.GetHealthchecker().SetReadTimeout(10 * time.Millisecond) + p.GetHealthchecker().SetWriteTimeout(10 * time.Millisecond) + f := New() + f.hcInterval = 10 * time.Millisecond + f.maxfails = 2 + f.SetProxy(p) + defer f.OnShutdown() + + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + + f.ServeDNS(context.TODO(), &test.ResponseWriter{}, req) + + time.Sleep(100 * time.Millisecond) + fails := p.Fails() + if !p.Down(f.maxfails) { + t.Errorf("Expected Proxy fails to be greater than %d, got %d", f.maxfails, fails) + } +} + +func TestHealthNoMaxFails(t *testing.T) { + defaultTimeout = 10 * time.Millisecond + + i := uint32(0) + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + if r.Question[0].Name == "." { + // health check, answer + atomic.AddUint32(&i, 1) + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + } + }) + defer s.Close() + + p := proxy.NewProxy("TestHealthNoMaxFails", s.Addr, transport.DNS) + p.GetHealthchecker().SetReadTimeout(10 * time.Millisecond) + p.GetHealthchecker().SetWriteTimeout(10 * time.Millisecond) + f := New() + f.maxfails = 0 + f.SetProxy(p) + defer f.OnShutdown() + + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + + f.ServeDNS(context.TODO(), &test.ResponseWriter{}, req) + + time.Sleep(20 * time.Millisecond) + i1 := atomic.LoadUint32(&i) + if i1 != 0 { + t.Errorf("Expected number of health checks to be %d, got %d", 0, i1) + } +} + +func TestHealthDomain(t *testing.T) { + defaultTimeout = 10 * time.Millisecond + + hcDomain := "example.org." 
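+	// The health checker below is configured (via SetDomain) to probe hcDomain instead of the default root zone.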
+ i := uint32(0) + q := uint32(0) + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + if atomic.LoadUint32(&q) == 0 { //drop the first query to trigger health-checking + atomic.AddUint32(&q, 1) + return + } + if r.Question[0].Name == hcDomain && r.RecursionDesired == true { + atomic.AddUint32(&i, 1) + } + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + p := proxy.NewProxy("TestHealthDomain", s.Addr, transport.DNS) + p.GetHealthchecker().SetReadTimeout(10 * time.Millisecond) + p.GetHealthchecker().SetWriteTimeout(10 * time.Millisecond) + p.GetHealthchecker().SetDomain(hcDomain) + f := New() + f.SetProxy(p) + defer f.OnShutdown() + + req := new(dns.Msg) + req.SetQuestion(".", dns.TypeNS) + + f.ServeDNS(context.TODO(), &test.ResponseWriter{}, req) + + time.Sleep(20 * time.Millisecond) + i1 := atomic.LoadUint32(&i) + if i1 != 1 { + t.Errorf("Expected number of health checks with Domain==%s to be %d, got %d", hcDomain, 1, i1) + } +} diff --git a/plugin/forward/log_test.go b/plugin/forward/log_test.go new file mode 100644 index 0000000..a7f0a85 --- /dev/null +++ b/plugin/forward/log_test.go @@ -0,0 +1,5 @@ +package forward + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/forward/metrics.go b/plugin/forward/metrics.go new file mode 100644 index 0000000..246dc65 --- /dev/null +++ b/plugin/forward/metrics.go @@ -0,0 +1,25 @@ +package forward + +import ( + "github.com/coredns/coredns/plugin" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +// Variables declared for monitoring. +var ( + healthcheckBrokenCount = promauto.NewCounter(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "forward", + Name: "healthcheck_broken_total", + Help: "Counter of the number of complete failures of the healthchecks.", + }) + + maxConcurrentRejectCount = promauto.NewCounter(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "forward", + Name: "max_concurrent_rejects_total", + Help: "Counter of the number of queries rejected because the concurrent queries were at maximum.", + }) +) diff --git a/plugin/forward/policy.go b/plugin/forward/policy.go new file mode 100644 index 0000000..7bd1f31 --- /dev/null +++ b/plugin/forward/policy.go @@ -0,0 +1,69 @@ +package forward + +import ( + "sync/atomic" + "time" + + "github.com/coredns/coredns/plugin/pkg/proxy" + "github.com/coredns/coredns/plugin/pkg/rand" +) + +// Policy defines a policy we use for selecting upstreams. +type Policy interface { + List([]*proxy.Proxy) []*proxy.Proxy + String() string +} + +// random is a policy that implements random upstream selection. +type random struct{} + +func (r *random) String() string { return "random" } + +func (r *random) List(p []*proxy.Proxy) []*proxy.Proxy { + switch len(p) { + case 1: + return p + case 2: + if rn.Int()%2 == 0 { + return []*proxy.Proxy{p[1], p[0]} // swap + } + return p + } + + perms := rn.Perm(len(p)) + rnd := make([]*proxy.Proxy, len(p)) + + for i, p1 := range perms { + rnd[i] = p[p1] + } + return rnd +} + +// roundRobin is a policy that selects hosts based on round robin ordering. +type roundRobin struct { + robin uint32 +} + +func (r *roundRobin) String() string { return "round_robin" } + +func (r *roundRobin) List(p []*proxy.Proxy) []*proxy.Proxy { + poolLen := uint32(len(p)) + i := atomic.AddUint32(&r.robin, 1) % poolLen + + robin := []*proxy.Proxy{p[i]} + robin = append(robin, p[:i]...) 
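+	// append the proxies after index i as well, so the rotated list contains every upstream exactly once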
+ robin = append(robin, p[i+1:]...) + + return robin +} + +// sequential is a policy that selects hosts based on sequential ordering. +type sequential struct{} + +func (r *sequential) String() string { return "sequential" } + +func (r *sequential) List(p []*proxy.Proxy) []*proxy.Proxy { + return p +} + +var rn = rand.New(time.Now().UnixNano()) diff --git a/plugin/forward/proxy_test.go b/plugin/forward/proxy_test.go new file mode 100644 index 0000000..daf5f96 --- /dev/null +++ b/plugin/forward/proxy_test.go @@ -0,0 +1,70 @@ +package forward + +import ( + "context" + "testing" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestProxy(t *testing.T) { + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + ret.Answer = append(ret.Answer, test.A("example.org. IN A 127.0.0.1")) + w.WriteMsg(ret) + }) + defer s.Close() + + c := caddy.NewTestController("dns", "forward . "+s.Addr) + fs, err := parseForward(c) + f := fs[0] + if err != nil { + t.Errorf("Failed to create forwarder: %s", err) + } + f.OnStartup() + defer f.OnShutdown() + + m := new(dns.Msg) + m.SetQuestion("example.org.", dns.TypeA) + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + + if _, err := f.ServeDNS(context.TODO(), rec, m); err != nil { + t.Fatal("Expected to receive reply, but didn't") + } + if x := rec.Msg.Answer[0].Header().Name; x != "example.org." { + t.Errorf("Expected %s, got %s", "example.org.", x) + } +} + +func TestProxyTLSFail(t *testing.T) { + // This is an udp/tcp test server, so we shouldn't reach it with TLS. + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + ret.Answer = append(ret.Answer, test.A("example.org. IN A 127.0.0.1")) + w.WriteMsg(ret) + }) + defer s.Close() + + c := caddy.NewTestController("dns", "forward . 
tls://"+s.Addr) + fs, err := parseForward(c) + f := fs[0] + if err != nil { + t.Errorf("Failed to create forwarder: %s", err) + } + f.OnStartup() + defer f.OnShutdown() + + m := new(dns.Msg) + m.SetQuestion("example.org.", dns.TypeA) + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + + if _, err := f.ServeDNS(context.TODO(), rec, m); err == nil { + t.Fatal("Expected *not* to receive reply, but got one") + } +} diff --git a/plugin/forward/setup.go b/plugin/forward/setup.go new file mode 100644 index 0000000..5341b7e --- /dev/null +++ b/plugin/forward/setup.go @@ -0,0 +1,300 @@ +package forward + +import ( + "crypto/tls" + "errors" + "fmt" + "path/filepath" + "strconv" + "time" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/dnstap" + "github.com/coredns/coredns/plugin/pkg/parse" + "github.com/coredns/coredns/plugin/pkg/proxy" + pkgtls "github.com/coredns/coredns/plugin/pkg/tls" + "github.com/coredns/coredns/plugin/pkg/transport" + + "github.com/miekg/dns" +) + +func init() { + plugin.Register("forward", setup) +} + +func setup(c *caddy.Controller) error { + fs, err := parseForward(c) + if err != nil { + return plugin.Error("forward", err) + } + for i := range fs { + f := fs[i] + if f.Len() > max { + return plugin.Error("forward", fmt.Errorf("more than %d TOs configured: %d", max, f.Len())) + } + + if i == len(fs)-1 { + // last forward: point next to next plugin + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + f.Next = next + return f + }) + } else { + // middle forward: point next to next forward + nextForward := fs[i+1] + dnsserver.GetConfig(c).AddPlugin(func(plugin.Handler) plugin.Handler { + f.Next = nextForward + return f + }) + } + + c.OnStartup(func() error { + return f.OnStartup() + }) + c.OnStartup(func() error { + if taph := dnsserver.GetConfig(c).Handler("dnstap"); taph != nil { + f.SetTapPlugin(taph.(*dnstap.Dnstap)) + } + return nil + }) + + c.OnShutdown(func() error { + return f.OnShutdown() + }) + } + + return nil +} + +// OnStartup starts a goroutines for all proxies. +func (f *Forward) OnStartup() (err error) { + for _, p := range f.proxies { + p.Start(f.hcInterval) + } + return nil +} + +// OnShutdown stops all configured proxies. +func (f *Forward) OnShutdown() error { + for _, p := range f.proxies { + p.Stop() + } + return nil +} + +func parseForward(c *caddy.Controller) ([]*Forward, error) { + var fs = []*Forward{} + for c.Next() { + f, err := parseStanza(c) + if err != nil { + return nil, err + } + fs = append(fs, f) + } + return fs, nil +} + +func parseStanza(c *caddy.Controller) (*Forward, error) { + f := New() + + if !c.Args(&f.from) { + return f, c.ArgErr() + } + origFrom := f.from + zones := plugin.Host(f.from).NormalizeExact() + if len(zones) == 0 { + return f, fmt.Errorf("unable to normalize '%s'", f.from) + } + f.from = zones[0] // there can only be one here, won't work with non-octet reverse + + if len(zones) > 1 { + log.Warningf("Unsupported CIDR notation: '%s' expands to multiple zones. Using only '%s'.", origFrom, f.from) + } + + to := c.RemainingArgs() + if len(to) == 0 { + return f, c.ArgErr() + } + + toHosts, err := parse.HostPortOrFile(to...) 
+ if err != nil { + return f, err + } + + transports := make([]string, len(toHosts)) + allowedTrans := map[string]bool{"dns": true, "tls": true} + for i, host := range toHosts { + trans, h := parse.Transport(host) + + if !allowedTrans[trans] { + return f, fmt.Errorf("'%s' is not supported as a destination protocol in forward: %s", trans, host) + } + p := proxy.NewProxy("forward", h, trans) + f.proxies = append(f.proxies, p) + transports[i] = trans + } + + for c.NextBlock() { + if err := parseBlock(c, f); err != nil { + return f, err + } + } + + if f.tlsServerName != "" { + f.tlsConfig.ServerName = f.tlsServerName + } + + // Initialize ClientSessionCache in tls.Config. This may speed up a TLS handshake + // in upcoming connections to the same TLS server. + f.tlsConfig.ClientSessionCache = tls.NewLRUClientSessionCache(len(f.proxies)) + + for i := range f.proxies { + // Only set this for proxies that need it. + if transports[i] == transport.TLS { + f.proxies[i].SetTLSConfig(f.tlsConfig) + } + f.proxies[i].SetExpire(f.expire) + f.proxies[i].GetHealthchecker().SetRecursionDesired(f.opts.HCRecursionDesired) + // when TLS is used, checks are set to tcp-tls + if f.opts.ForceTCP && transports[i] != transport.TLS { + f.proxies[i].GetHealthchecker().SetTCPTransport() + } + f.proxies[i].GetHealthchecker().SetDomain(f.opts.HCDomain) + } + + return f, nil +} + +func parseBlock(c *caddy.Controller, f *Forward) error { + config := dnsserver.GetConfig(c) + switch c.Val() { + case "except": + ignore := c.RemainingArgs() + if len(ignore) == 0 { + return c.ArgErr() + } + for i := 0; i < len(ignore); i++ { + f.ignored = append(f.ignored, plugin.Host(ignore[i]).NormalizeExact()...) + } + case "max_fails": + if !c.NextArg() { + return c.ArgErr() + } + n, err := strconv.ParseUint(c.Val(), 10, 32) + if err != nil { + return err + } + f.maxfails = uint32(n) + case "health_check": + if !c.NextArg() { + return c.ArgErr() + } + dur, err := time.ParseDuration(c.Val()) + if err != nil { + return err + } + if dur < 0 { + return fmt.Errorf("health_check can't be negative: %d", dur) + } + f.hcInterval = dur + f.opts.HCDomain = "." + + for c.NextArg() { + switch hcOpts := c.Val(); hcOpts { + case "no_rec": + f.opts.HCRecursionDesired = false + case "domain": + if !c.NextArg() { + return c.ArgErr() + } + hcDomain := c.Val() + if _, ok := dns.IsDomainName(hcDomain); !ok { + return fmt.Errorf("health_check: invalid domain name %s", hcDomain) + } + f.opts.HCDomain = plugin.Name(hcDomain).Normalize() + default: + return fmt.Errorf("health_check: unknown option %s", hcOpts) + } + } + + case "force_tcp": + if c.NextArg() { + return c.ArgErr() + } + f.opts.ForceTCP = true + case "prefer_udp": + if c.NextArg() { + return c.ArgErr() + } + f.opts.PreferUDP = true + case "tls": + args := c.RemainingArgs() + if len(args) > 3 { + return c.ArgErr() + } + + for i := range args { + if !filepath.IsAbs(args[i]) && config.Root != "" { + args[i] = filepath.Join(config.Root, args[i]) + } + } + tlsConfig, err := pkgtls.NewTLSConfigFromArgs(args...) 
+ if err != nil { + return err + } + f.tlsConfig = tlsConfig + case "tls_servername": + if !c.NextArg() { + return c.ArgErr() + } + f.tlsServerName = c.Val() + case "expire": + if !c.NextArg() { + return c.ArgErr() + } + dur, err := time.ParseDuration(c.Val()) + if err != nil { + return err + } + if dur < 0 { + return fmt.Errorf("expire can't be negative: %s", dur) + } + f.expire = dur + case "policy": + if !c.NextArg() { + return c.ArgErr() + } + switch x := c.Val(); x { + case "random": + f.p = &random{} + case "round_robin": + f.p = &roundRobin{} + case "sequential": + f.p = &sequential{} + default: + return c.Errf("unknown policy '%s'", x) + } + case "max_concurrent": + if !c.NextArg() { + return c.ArgErr() + } + n, err := strconv.Atoi(c.Val()) + if err != nil { + return err + } + if n < 0 { + return fmt.Errorf("max_concurrent can't be negative: %d", n) + } + f.ErrLimitExceeded = errors.New("concurrent queries exceeded maximum " + c.Val()) + f.maxConcurrent = int64(n) + + default: + return c.Errf("unknown property '%s'", c.Val()) + } + + return nil +} + +const max = 15 // Maximum number of upstreams. diff --git a/plugin/forward/setup_policy_test.go b/plugin/forward/setup_policy_test.go new file mode 100644 index 0000000..13466d7 --- /dev/null +++ b/plugin/forward/setup_policy_test.go @@ -0,0 +1,47 @@ +package forward + +import ( + "strings" + "testing" + + "github.com/coredns/caddy" +) + +func TestSetupPolicy(t *testing.T) { + tests := []struct { + input string + shouldErr bool + expectedPolicy string + expectedErr string + }{ + // positive + {"forward . 127.0.0.1 {\npolicy random\n}\n", false, "random", ""}, + {"forward . 127.0.0.1 {\npolicy round_robin\n}\n", false, "round_robin", ""}, + {"forward . 127.0.0.1 {\npolicy sequential\n}\n", false, "sequential", ""}, + // negative + {"forward . 127.0.0.1 {\npolicy random2\n}\n", true, "random", "unknown policy"}, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + fs, err := parseForward(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: expected error but found %s for input %s", i, err, test.input) + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err) + } + + if !strings.Contains(err.Error(), test.expectedErr) { + t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input) + } + } + + if !test.shouldErr && (len(fs) == 0 || fs[0].p.String() != test.expectedPolicy) { + t.Errorf("Test %d: expected: %s, got: %s", i, test.expectedPolicy, fs[0].p.String()) + } + } +} diff --git a/plugin/forward/setup_test.go b/plugin/forward/setup_test.go new file mode 100644 index 0000000..cf046b4 --- /dev/null +++ b/plugin/forward/setup_test.go @@ -0,0 +1,335 @@ +package forward + +import ( + "os" + "reflect" + "strings" + "testing" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin/pkg/proxy" + + "github.com/miekg/dns" +) + +func TestSetup(t *testing.T) { + tests := []struct { + input string + shouldErr bool + expectedFrom string + expectedIgnored []string + expectedFails uint32 + expectedOpts proxy.Options + expectedErr string + }{ + // positive + {"forward . 127.0.0.1", false, ".", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, ""}, + {"forward . 
127.0.0.1 {\nhealth_check 0.5s domain example.org\n}\n", false, ".", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "example.org."}, ""}, + {"forward . 127.0.0.1 {\nexcept miek.nl\n}\n", false, ".", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, ""}, + {"forward . 127.0.0.1 {\nmax_fails 3\n}\n", false, ".", nil, 3, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, ""}, + {"forward . 127.0.0.1 {\nforce_tcp\n}\n", false, ".", nil, 2, proxy.Options{ForceTCP: true, HCRecursionDesired: true, HCDomain: "."}, ""}, + {"forward . 127.0.0.1 {\nprefer_udp\n}\n", false, ".", nil, 2, proxy.Options{PreferUDP: true, HCRecursionDesired: true, HCDomain: "."}, ""}, + {"forward . 127.0.0.1 {\nforce_tcp\nprefer_udp\n}\n", false, ".", nil, 2, proxy.Options{PreferUDP: true, ForceTCP: true, HCRecursionDesired: true, HCDomain: "."}, ""}, + {"forward . 127.0.0.1:53", false, ".", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, ""}, + {"forward . 127.0.0.1:8080", false, ".", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, ""}, + {"forward . [::1]:53", false, ".", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, ""}, + {"forward . [2003::1]:53", false, ".", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, ""}, + {"forward . 127.0.0.1 \n", false, ".", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, ""}, + {"forward 10.9.3.0/18 127.0.0.1", false, "0.9.10.in-addr.arpa.", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, ""}, + {`forward . ::1 + forward com ::2`, false, ".", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, "plugin"}, + // negative + {"forward . a27.0.0.1", true, "", nil, 0, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, "not an IP"}, + {"forward . 127.0.0.1 {\nblaatl\n}\n", true, "", nil, 0, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, "unknown property"}, + {"forward . 127.0.0.1 {\nhealth_check 0.5s domain\n}\n", true, "", nil, 0, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, "Wrong argument count or unexpected line ending after 'domain'"}, + {"forward . 
https://127.0.0.1 \n", true, ".", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, "'https' is not supported as a destination protocol in forward: https://127.0.0.1"}, + {"forward xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx 127.0.0.1 \n", true, ".", nil, 2, proxy.Options{HCRecursionDesired: true, HCDomain: "."}, "unable to normalize 'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'"}, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + fs, err := parseForward(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: expected error but found %s for input %s", i, err, test.input) + } + + if err != nil { + if !test.shouldErr { + t.Fatalf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err) + } + + if !strings.Contains(err.Error(), test.expectedErr) { + t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input) + } + } + + if !test.shouldErr { + f := fs[0] + if f.from != test.expectedFrom { + t.Errorf("Test %d: expected: %s, got: %s", i, test.expectedFrom, f.from) + } + if test.expectedIgnored != nil { + if !reflect.DeepEqual(f.ignored, test.expectedIgnored) { + t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedIgnored, f.ignored) + } + } + if f.maxfails != test.expectedFails { + t.Errorf("Test %d: expected: %d, got: %d", i, test.expectedFails, f.maxfails) + } + if f.opts != test.expectedOpts { + t.Errorf("Test %d: expected: %v, got: %v", i, test.expectedOpts, f.opts) + } + } + } +} + +func TestSetupTLS(t *testing.T) { + tests := []struct { + input string + shouldErr bool + expectedServerName string + expectedErr string + }{ + // positive + {`forward . tls://127.0.0.1 { + tls_servername dns + }`, false, "dns", ""}, + {`forward . 127.0.0.1 { + tls_servername dns + }`, false, "", ""}, + {`forward . 127.0.0.1 { + tls + }`, false, "", ""}, + {`forward . tls://127.0.0.1`, false, "", ""}, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + fs, err := parseForward(c) + f := fs[0] + + if test.shouldErr && err == nil { + t.Errorf("Test %d: expected error but found %s for input %s", i, err, test.input) + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err) + } + + if !strings.Contains(err.Error(), test.expectedErr) { + t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input) + } + } + + if !test.shouldErr && test.expectedServerName != "" && test.expectedServerName != f.tlsConfig.ServerName { + t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedServerName, f.tlsConfig.ServerName) + } + + if !test.shouldErr && test.expectedServerName != "" && test.expectedServerName != f.proxies[0].GetHealthchecker().GetTLSConfig().ServerName { + t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedServerName, f.proxies[0].GetHealthchecker().GetTLSConfig().ServerName) + } + } +} + +func TestSetupResolvconf(t *testing.T) { + const resolv = "resolv.conf" + if err := os.WriteFile(resolv, + []byte(`nameserver 10.10.255.252 +nameserver 10.10.255.253`), 0666); err != nil { + t.Fatalf("Failed to write resolv.conf file: %s", err) + } + defer os.Remove(resolv) + + tests := []struct { + input string + shouldErr bool + expectedErr string + expectedNames []string + }{ + // pass + {`forward . 
` + resolv, false, "", []string{"10.10.255.252:53", "10.10.255.253:53"}}, + // fail + {`forward . /dev/null`, true, "no nameservers", nil}, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + fs, err := parseForward(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: expected error but found %s for input %s", i, err, test.input) + continue + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err) + } + + if !strings.Contains(err.Error(), test.expectedErr) { + t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input) + } + } + + if test.shouldErr { + continue + } + + f := fs[0] + for j, n := range test.expectedNames { + addr := f.proxies[j].Addr() + if n != addr { + t.Errorf("Test %d, expected %q, got %q", j, n, addr) + } + } + + for _, p := range f.proxies { + p.Healthcheck() // this should almost always err, we don't care it shouldn't crash + } + } +} + +func TestSetupMaxConcurrent(t *testing.T) { + tests := []struct { + input string + shouldErr bool + expectedVal int64 + expectedErr string + }{ + // positive + {"forward . 127.0.0.1 {\nmax_concurrent 1000\n}\n", false, 1000, ""}, + // negative + {"forward . 127.0.0.1 {\nmax_concurrent many\n}\n", true, 0, "invalid"}, + {"forward . 127.0.0.1 {\nmax_concurrent -4\n}\n", true, 0, "negative"}, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + fs, err := parseForward(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: expected error but found %s for input %s", i, err, test.input) + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err) + } + + if !strings.Contains(err.Error(), test.expectedErr) { + t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input) + } + } + + if test.shouldErr { + continue + } + f := fs[0] + if f.maxConcurrent != test.expectedVal { + t.Errorf("Test %d: expected: %d, got: %d", i, test.expectedVal, f.maxConcurrent) + } + } +} + +func TestSetupHealthCheck(t *testing.T) { + tests := []struct { + input string + shouldErr bool + expectedRecVal bool + expectedDomain string + expectedErr string + }{ + // positive + {"forward . 127.0.0.1\n", false, true, ".", ""}, + {"forward . 127.0.0.1 {\nhealth_check 0.5s\n}\n", false, true, ".", ""}, + {"forward . 127.0.0.1 {\nhealth_check 0.5s no_rec\n}\n", false, false, ".", ""}, + {"forward . 127.0.0.1 {\nhealth_check 0.5s no_rec domain example.org\n}\n", false, false, "example.org.", ""}, + {"forward . 127.0.0.1 {\nhealth_check 0.5s domain example.org\n}\n", false, true, "example.org.", ""}, + {"forward . 127.0.0.1 {\nhealth_check 0.5s domain .\n}\n", false, true, ".", ""}, + {"forward . 127.0.0.1 {\nhealth_check 0.5s domain example.org.\n}\n", false, true, "example.org.", ""}, + // negative + {"forward . 127.0.0.1 {\nhealth_check no_rec\n}\n", true, true, ".", "time: invalid duration"}, + {"forward . 127.0.0.1 {\nhealth_check domain example.org\n}\n", true, true, "example.org", "time: invalid duration"}, + {"forward . 127.0.0.1 {\nhealth_check 0.5s rec\n}\n", true, true, ".", "health_check: unknown option rec"}, + {"forward . 127.0.0.1 {\nhealth_check 0.5s domain\n}\n", true, true, ".", "Wrong argument count or unexpected line ending after 'domain'"}, + {"forward . 
127.0.0.1 {\nhealth_check 0.5s domain example..org\n}\n", true, true, ".", "health_check: invalid domain name"}, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + fs, err := parseForward(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: expected error but found %s for input %s", i, err, test.input) + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err) + } + if !strings.Contains(err.Error(), test.expectedErr) { + t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input) + } + } + + if test.shouldErr { + continue + } + + f := fs[0] + if f.opts.HCRecursionDesired != test.expectedRecVal || f.proxies[0].GetHealthchecker().GetRecursionDesired() != test.expectedRecVal || + f.opts.HCDomain != test.expectedDomain || f.proxies[0].GetHealthchecker().GetDomain() != test.expectedDomain || !dns.IsFqdn(f.proxies[0].GetHealthchecker().GetDomain()) { + t.Errorf("Test %d: expectedRec: %v, got: %v. expectedDomain: %s, got: %s. ", i, test.expectedRecVal, f.opts.HCRecursionDesired, test.expectedDomain, f.opts.HCDomain) + } + } +} + +func TestMultiForward(t *testing.T) { + input := ` + forward 1st.example.org 10.0.0.1 + forward 2nd.example.org 10.0.0.2 + forward 3rd.example.org 10.0.0.3 + ` + + c := caddy.NewTestController("dns", input) + setup(c) + dnsserver.NewServer("", []*dnsserver.Config{dnsserver.GetConfig(c)}) + + handlers := dnsserver.GetConfig(c).Handlers() + f1, ok := handlers[0].(*Forward) + if !ok { + t.Fatalf("expected first plugin to be Forward, got %v", reflect.TypeOf(f1.Next)) + } + + if f1.from != "1st.example.org." { + t.Errorf("expected first forward from \"1st.example.org.\", got %q", f1.from) + } + if f1.Next == nil { + t.Fatal("expected first forward to point to next forward instance, not nil") + } + + f2, ok := f1.Next.(*Forward) + if !ok { + t.Fatalf("expected second plugin to be Forward, got %v", reflect.TypeOf(f1.Next)) + } + if f2.from != "2nd.example.org." { + t.Errorf("expected second forward from \"2nd.example.org.\", got %q", f2.from) + } + if f2.Next == nil { + t.Fatal("expected second forward to point to third forward instance, got nil") + } + + f3, ok := f2.Next.(*Forward) + if !ok { + t.Fatalf("expected third plugin to be Forward, got %v", reflect.TypeOf(f2.Next)) + } + if f3.from != "3rd.example.org." { + t.Errorf("expected third forward from \"3rd.example.org.\", got %q", f3.from) + } + if f3.Next != nil { + t.Error("expected third plugin to be last, but Next is not nil") + } +} diff --git a/plugin/geoip/README.md b/plugin/geoip/README.md new file mode 100644 index 0000000..febad8a --- /dev/null +++ b/plugin/geoip/README.md @@ -0,0 +1,117 @@ +# geoip + +## Name + +*geoip* - Lookup maxmind geoip2 databases using the client IP, then add associated geoip data to the context request. + +## Description + +The *geoip* plugin add geo location data associated with the client IP, it allows you to configure a [geoIP2 maxmind database](https://dev.maxmind.com/geoip/docs/databases) to add the geo location data associated with the IP address. + +The data is added leveraging the *metadata* plugin, values can then be retrieved using it as well, for example: + +```go +import ( + "strconv" + "github.com/coredns/coredns/plugin/metadata" +) +// ... 
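+// metadata.ValueFunc returns nil when the label was not set for this request; otherwise it
+// returns a func() string yielding the stored value.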
+if getLongitude := metadata.ValueFunc(ctx, "geoip/longitude"); getLongitude != nil { + if longitude, err := strconv.ParseFloat(getLongitude(), 64); err == nil { + // Do something useful with longitude. + } +} else { + // The metadata label geoip/longitude for some reason, was not set. +} +// ... +``` + +## Databases + +The supported databases use city schema such as `City` and `Enterprise`. Other databases types with different schemas are not supported yet. + +You can download a [free and public City database](https://dev.maxmind.com/geoip/geolite2-free-geolocation-data). + +## Syntax + +```text +geoip [DBFILE] +``` + +or + +```text +geoip [DBFILE] { + [edns-subnet] +} +``` + +* **DBFILE** the mmdb database file path. We recommend updating your mmdb database periodically for more accurate results. +* `edns-subnet`: Optional. Use [EDNS0 subnet](https://en.wikipedia.org/wiki/EDNS_Client_Subnet) (if present) for Geo IP instead of the source IP of the DNS request. This helps identifying the closest source IP address through intermediary DNS resolvers, and it also makes GeoIP testing easy: `dig +subnet=1.2.3.4 @dns-server.example.com www.geo-aware.com`. + + **NOTE:** due to security reasons, recursive DNS resolvers may mask a few bits off of the clients' IP address, which can cause inaccuracies in GeoIP resolution. + + There is no defined mask size in the standards, but there are examples: [RFC 7871's example](https://datatracker.ietf.org/doc/html/rfc7871#section-13) conceals the last 72 bits of an IPv6 source address, and NS1 Help Center [mentions](https://help.ns1.com/hc/en-us/articles/360020256573-About-the-EDNS-Client-Subnet-ECS-DNS-extension) that ECS-enabled DNS resolvers send only the first three octets (eg. /24) of the source IPv4 address. + +## Examples + +The following configuration configures the `City` database, and looks up geolocation based on EDNS0 subnet if present. + +```txt +. { + geoip /opt/geoip2/db/GeoLite2-City.mmdb { + edns-subnet + } + metadata # Note that metadata plugin must be enabled as well. +} +``` + +The *view* plugin can use *geoip* metadata as selection criteria to provide GSLB functionality. +In this example, clients from the city "Exampleshire" will receive answers for `example.com` from the zone defined in +`example.com.exampleshire-db`. All other clients will receive answers from the zone defined in `example.com.db`. +Note that the order of the two `example.com` server blocks below is important; the default viewless server block +must be last. + +```txt +example.com { + view exampleshire { + expr metadata('geoip/city/name') == 'Exampleshire' + } + geoip /opt/geoip2/db/GeoLite2-City.mmdb + metadata + file example.com.exampleshire-db +} + +example.com { + file example.com.db +} +``` + +## Metadata Labels + +A limited set of fields will be exported as labels, all values are stored using strings **regardless of their underlying value type**, and therefore you may have to convert it back to its original type, note that numeric values are always represented in base 10. + +| Label | Type | Example | Description +| :----------------------------------- | :-------- | :-------------- | :------------------ +| `geoip/city/name` | `string` | `Cambridge` | Then city name in English language. +| `geoip/country/code` | `string` | `GB` | Country [ISO 3166-1](https://en.wikipedia.org/wiki/ISO_3166-1) code. +| `geoip/country/name` | `string` | `United Kingdom` | The country name in English language. 
+| `geoip/country/is_in_european_union` | `bool` | `false` | Either `true` or `false`. +| `geoip/continent/code` | `string` | `EU` | See [Continent codes](#ContinentCodes). +| `geoip/continent/name` | `string` | `Europe` | The continent name in English language. +| `geoip/latitude` | `float64` | `52.2242` | Base 10, max available precision. +| `geoip/longitude` | `float64` | `0.1315` | Base 10, max available precision. +| `geoip/timezone` | `string` | `Europe/London` | The timezone. +| `geoip/postalcode` | `string` | `CB4` | The postal code. + +## Continent Codes + +| Value | Continent (EN) | +| :---- | :------------- | +| AF | Africa | +| AN | Antarctica | +| AS | Asia | +| EU | Europe | +| NA | North America | +| OC | Oceania | +| SA | South America | diff --git a/plugin/geoip/city.go b/plugin/geoip/city.go new file mode 100644 index 0000000..2e5d9f7 --- /dev/null +++ b/plugin/geoip/city.go @@ -0,0 +1,58 @@ +package geoip + +import ( + "context" + "strconv" + + "github.com/coredns/coredns/plugin/metadata" + + "github.com/oschwald/geoip2-golang" +) + +const defaultLang = "en" + +func (g GeoIP) setCityMetadata(ctx context.Context, data *geoip2.City) { + // Set labels for city, country and continent names. + cityName := data.City.Names[defaultLang] + metadata.SetValueFunc(ctx, pluginName+"/city/name", func() string { + return cityName + }) + countryName := data.Country.Names[defaultLang] + metadata.SetValueFunc(ctx, pluginName+"/country/name", func() string { + return countryName + }) + continentName := data.Continent.Names[defaultLang] + metadata.SetValueFunc(ctx, pluginName+"/continent/name", func() string { + return continentName + }) + + countryCode := data.Country.IsoCode + metadata.SetValueFunc(ctx, pluginName+"/country/code", func() string { + return countryCode + }) + isInEurope := strconv.FormatBool(data.Country.IsInEuropeanUnion) + metadata.SetValueFunc(ctx, pluginName+"/country/is_in_european_union", func() string { + return isInEurope + }) + continentCode := data.Continent.Code + metadata.SetValueFunc(ctx, pluginName+"/continent/code", func() string { + return continentCode + }) + + latitude := strconv.FormatFloat(data.Location.Latitude, 'f', -1, 64) + metadata.SetValueFunc(ctx, pluginName+"/latitude", func() string { + return latitude + }) + longitude := strconv.FormatFloat(data.Location.Longitude, 'f', -1, 64) + metadata.SetValueFunc(ctx, pluginName+"/longitude", func() string { + return longitude + }) + timeZone := data.Location.TimeZone + metadata.SetValueFunc(ctx, pluginName+"/timezone", func() string { + return timeZone + }) + postalCode := data.Postal.Code + metadata.SetValueFunc(ctx, pluginName+"/postalcode", func() string { + return postalCode + }) +} diff --git a/plugin/geoip/geoip.go b/plugin/geoip/geoip.go new file mode 100644 index 0000000..3451c82 --- /dev/null +++ b/plugin/geoip/geoip.go @@ -0,0 +1,107 @@ +// Package geoip implements a max mind database plugin. +package geoip + +import ( + "context" + "fmt" + "net" + "path/filepath" + + "github.com/coredns/coredns/plugin" + clog "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" + "github.com/oschwald/geoip2-golang" +) + +var log = clog.NewWithPlugin(pluginName) + +// GeoIP is a plugin that add geo location data to the request context by looking up a maxmind +// geoIP2 database, and which data can be later consumed by other middlewares. 
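+// The lookup key is the source IP of the request, or the address carried in an EDNS0
+// client-subnet option when the edns0 option is enabled.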
+type GeoIP struct { + Next plugin.Handler + db db + edns0 bool +} + +type db struct { + *geoip2.Reader + // provides defines the schemas that can be obtained by querying this database, by using + // bitwise operations. + provides int +} + +const ( + city = 1 << iota +) + +var probingIP = net.ParseIP("127.0.0.1") + +func newGeoIP(dbPath string, edns0 bool) (*GeoIP, error) { + reader, err := geoip2.Open(dbPath) + if err != nil { + return nil, fmt.Errorf("failed to open database file: %v", err) + } + db := db{Reader: reader} + schemas := []struct { + provides int + name string + validate func() error + }{ + {name: "city", provides: city, validate: func() error { _, err := reader.City(probingIP); return err }}, + } + // Query the database to figure out the database type. + for _, schema := range schemas { + if err := schema.validate(); err != nil { + // If we get an InvalidMethodError then we know this database does not provide that schema. + if _, ok := err.(geoip2.InvalidMethodError); !ok { + return nil, fmt.Errorf("unexpected failure looking up database %q schema %q: %v", filepath.Base(dbPath), schema.name, err) + } + } else { + db.provides |= schema.provides + } + } + + if db.provides&city == 0 { + return nil, fmt.Errorf("database does not provide city schema") + } + + return &GeoIP{db: db, edns0: edns0}, nil +} + +// ServeDNS implements the plugin.Handler interface. +func (g GeoIP) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + return plugin.NextOrFailure(pluginName, g.Next, ctx, w, r) +} + +// Metadata implements the metadata.Provider Interface in the metadata plugin, and is used to store +// the data associated with the source IP of every request. +func (g GeoIP) Metadata(ctx context.Context, state request.Request) context.Context { + srcIP := net.ParseIP(state.IP()) + + if g.edns0 { + if o := state.Req.IsEdns0(); o != nil { + for _, s := range o.Option { + if e, ok := s.(*dns.EDNS0_SUBNET); ok { + srcIP = e.Address + break + } + } + } + } + + switch { + case g.db.provides&city == city: + data, err := g.db.City(srcIP) + if err != nil { + log.Debugf("Setting up metadata failed due to database lookup error: %v", err) + return ctx + } + g.setCityMetadata(ctx, data) + } + return ctx +} + +// Name implements the Handler interface. +func (g GeoIP) Name() string { return pluginName } diff --git a/plugin/geoip/geoip_test.go b/plugin/geoip/geoip_test.go new file mode 100644 index 0000000..b11fc8b --- /dev/null +++ b/plugin/geoip/geoip_test.go @@ -0,0 +1,90 @@ +package geoip + +import ( + "context" + "fmt" + "net" + "testing" + + "github.com/coredns/coredns/plugin/metadata" + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +func TestMetadata(t *testing.T) { + tests := []struct { + label string + expectedValue string + }{ + {"geoip/city/name", "Cambridge"}, + + {"geoip/country/code", "GB"}, + {"geoip/country/name", "United Kingdom"}, + // is_in_european_union is set to true only to work around bool zero value, and test is really being set. + {"geoip/country/is_in_european_union", "true"}, + + {"geoip/continent/code", "EU"}, + {"geoip/continent/name", "Europe"}, + + {"geoip/latitude", "52.2242"}, + {"geoip/longitude", "0.1315"}, + {"geoip/timezone", "Europe/London"}, + {"geoip/postalcode", "CB4"}, + } + + knownIPAddr := "81.2.69.142" // This IP should be be part of the CDIR address range used to create the database fixtures. 
+ for _, tc := range tests { + t.Run(fmt.Sprintf("%s/%s", tc.label, "direct"), func(t *testing.T) { + geoIP, err := newGeoIP(cityDBPath, false) + if err != nil { + t.Fatalf("unable to create geoIP plugin: %v", err) + } + state := request.Request{ + Req: new(dns.Msg), + W: &test.ResponseWriter{RemoteIP: knownIPAddr}, + } + testMetadata(t, state, geoIP, tc.label, tc.expectedValue) + }) + + t.Run(fmt.Sprintf("%s/%s", tc.label, "subnet"), func(t *testing.T) { + geoIP, err := newGeoIP(cityDBPath, true) + if err != nil { + t.Fatalf("unable to create geoIP plugin: %v", err) + } + state := request.Request{ + Req: new(dns.Msg), + W: &test.ResponseWriter{RemoteIP: "127.0.0.1"}, + } + state.Req.SetEdns0(4096, false) + if o := state.Req.IsEdns0(); o != nil { + addr := net.ParseIP(knownIPAddr) + o.Option = append(o.Option, (&dns.EDNS0_SUBNET{ + SourceNetmask: 32, + Address: addr, + })) + } + testMetadata(t, state, geoIP, tc.label, tc.expectedValue) + }) + } +} + +func testMetadata(t *testing.T, state request.Request, geoIP *GeoIP, label, expectedValue string) { + ctx := metadata.ContextWithMetadata(context.Background()) + rCtx := geoIP.Metadata(ctx, state) + if fmt.Sprintf("%p", ctx) != fmt.Sprintf("%p", rCtx) { + t.Errorf("returned context is expected to be the same one passed in the Metadata function") + } + + fn := metadata.ValueFunc(ctx, label) + if fn == nil { + t.Errorf("label %q not set in metadata plugin context", label) + return + } + value := fn() + if value != expectedValue { + t.Errorf("expected value for label %q should be %q, got %q instead", + label, expectedValue, value) + } +} diff --git a/plugin/geoip/setup.go b/plugin/geoip/setup.go new file mode 100644 index 0000000..7f6e16f --- /dev/null +++ b/plugin/geoip/setup.go @@ -0,0 +1,57 @@ +package geoip + +import ( + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" +) + +const pluginName = "geoip" + +func init() { plugin.Register(pluginName, setup) } + +func setup(c *caddy.Controller) error { + geoip, err := geoipParse(c) + if err != nil { + return plugin.Error(pluginName, err) + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + geoip.Next = next + return geoip + }) + + return nil +} + +func geoipParse(c *caddy.Controller) (*GeoIP, error) { + var dbPath string + var edns0 bool + + for c.Next() { + if !c.NextArg() { + return nil, c.ArgErr() + } + if dbPath != "" { + return nil, c.Errf("configuring multiple databases is not supported") + } + dbPath = c.Val() + // There shouldn't be any more arguments. 
+ if len(c.RemainingArgs()) != 0 { + return nil, c.ArgErr() + } + + for c.NextBlock() { + if c.Val() != "edns-subnet" { + return nil, c.Errf("unknown property %q", c.Val()) + } + edns0 = true + } + } + + geoIP, err := newGeoIP(dbPath, edns0) + if err != nil { + return geoIP, c.Err(err.Error()) + } + return geoIP, nil +} diff --git a/plugin/geoip/setup_test.go b/plugin/geoip/setup_test.go new file mode 100644 index 0000000..b9b0030 --- /dev/null +++ b/plugin/geoip/setup_test.go @@ -0,0 +1,110 @@ +package geoip + +import ( + "fmt" + "net" + "path/filepath" + "strings" + "testing" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" +) + +var ( + fixturesDir = "./testdata" + cityDBPath = filepath.Join(fixturesDir, "GeoLite2-City.mmdb") + unknownDBPath = filepath.Join(fixturesDir, "GeoLite2-UnknownDbType.mmdb") +) + +func TestProbingIP(t *testing.T) { + if probingIP == nil { + t.Fatalf("Invalid probing IP: %q", probingIP) + } +} + +func TestSetup(t *testing.T) { + c := caddy.NewTestController("dns", fmt.Sprintf("%s %s", pluginName, cityDBPath)) + plugins := dnsserver.GetConfig(c).Plugin + if len(plugins) != 0 { + t.Fatalf("Expected zero plugins after setup, %d found", len(plugins)) + } + if err := setup(c); err != nil { + t.Fatalf("Expected no errors, but got: %v", err) + } + plugins = dnsserver.GetConfig(c).Plugin + if len(plugins) != 1 { + t.Fatalf("Expected one plugin after setup, %d found", len(plugins)) + } +} + +func TestGeoIPParse(t *testing.T) { + c := caddy.NewTestController("dns", fmt.Sprintf("%s %s", pluginName, cityDBPath)) + if err := setup(c); err != nil { + t.Fatalf("Expected no errors, but got: %v", err) + } + + tests := []struct { + shouldErr bool + config string + expectedErr string + expectedDBType int + }{ + // Valid + {false, fmt.Sprintf("%s %s\n", pluginName, cityDBPath), "", city}, + {false, fmt.Sprintf("%s %s { edns-subnet }", pluginName, cityDBPath), "", city}, + + // Invalid + {true, pluginName, "Wrong argument count", 0}, + {true, fmt.Sprintf("%s %s {\n\tlanguages en fr es zh-CN\n}\n", pluginName, cityDBPath), "unknown property \"languages\"", 0}, + {true, fmt.Sprintf("%s %s\n%s %s\n", pluginName, cityDBPath, pluginName, cityDBPath), "configuring multiple databases is not supported", 0}, + {true, fmt.Sprintf("%s 1 2 3", pluginName), "Wrong argument count", 0}, + {true, fmt.Sprintf("%s { }", pluginName), "Error during parsing", 0}, + {true, fmt.Sprintf("%s /dbpath { city }", pluginName), "unknown property \"city\"", 0}, + {true, fmt.Sprintf("%s /invalidPath\n", pluginName), "failed to open database file: open /invalidPath: no such file or directory", 0}, + {true, fmt.Sprintf("%s %s\n", pluginName, unknownDBPath), "reader does not support the \"UnknownDbType\" database type", 0}, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.config) + geoIP, err := geoipParse(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: expected error but found none for input %s", i, test.config) + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.config, err) + } + + if !strings.Contains(err.Error(), test.expectedErr) { + t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.config) + } + continue + } + + if geoIP.db.Reader == nil { + t.Errorf("Test %d: after parsing database reader should be initialized", i) + } + + if geoIP.db.provides&test.expectedDBType == 0 { + t.Errorf("Test %d: expected 
db type %d not found, database file provides %d", i, test.expectedDBType, geoIP.db.provides) + } + } + + // Set nil probingIP to test unexpected validate error() + defer func(ip net.IP) { probingIP = ip }(probingIP) + probingIP = nil + + c = caddy.NewTestController("dns", fmt.Sprintf("%s %s\n", pluginName, cityDBPath)) + _, err := geoipParse(c) + if err != nil { + expectedErr := "unexpected failure looking up database" + if !strings.Contains(err.Error(), expectedErr) { + t.Errorf("expected error to contain: %s", expectedErr) + } + } else { + t.Errorf("with a nil probingIP test is expected to fail") + } +} diff --git a/plugin/geoip/testdata/GeoLite2-City.mmdb b/plugin/geoip/testdata/GeoLite2-City.mmdb Binary files differnew file mode 100644 index 0000000..cd79ed9 --- /dev/null +++ b/plugin/geoip/testdata/GeoLite2-City.mmdb diff --git a/plugin/geoip/testdata/GeoLite2-UnknownDbType.mmdb b/plugin/geoip/testdata/GeoLite2-UnknownDbType.mmdb Binary files differnew file mode 100644 index 0000000..23efbf3 --- /dev/null +++ b/plugin/geoip/testdata/GeoLite2-UnknownDbType.mmdb diff --git a/plugin/geoip/testdata/README.md b/plugin/geoip/testdata/README.md new file mode 100644 index 0000000..2f6f884 --- /dev/null +++ b/plugin/geoip/testdata/README.md @@ -0,0 +1,112 @@ +# testdata +This directory contains mmdb database files used during the testing of this plugin. + +# Create mmdb database files +If you need to change them to add a new value, or field the best is to recreate them, the code snipped used to create them initially is provided next. + +```golang +package main + +import ( + "log" + "net" + "os" + + "github.com/maxmind/mmdbwriter" + "github.com/maxmind/mmdbwriter/inserter" + "github.com/maxmind/mmdbwriter/mmdbtype" +) + +const cdir = "81.2.69.142/32" + +// Create new mmdb database fixtures in this directory. +func main() { + createCityDB("GeoLite2-City.mmdb", "DBIP-City-Lite") + // Create unkwnon database type. + createCityDB("GeoLite2-UnknownDbType.mmdb", "UnknownDbType") +} + +func createCityDB(dbName, dbType string) { + // Load a database writer. + writer, err := mmdbwriter.New(mmdbwriter.Options{DatabaseType: dbType}) + if err != nil { + log.Fatal(err) + } + + // Define and insert the new data. + _, ip, err := net.ParseCIDR(cdir) + if err != nil { + log.Fatal(err) + } + + // TODO(snebel29): Find an alternative location in Europe Union. 
+ record := mmdbtype.Map{ + "city": mmdbtype.Map{ + "geoname_id": mmdbtype.Uint64(2653941), + "names": mmdbtype.Map{ + "en": mmdbtype.String("Cambridge"), + "es": mmdbtype.String("Cambridge"), + }, + }, + "continent": mmdbtype.Map{ + "code": mmdbtype.String("EU"), + "geoname_id": mmdbtype.Uint64(6255148), + "names": mmdbtype.Map{ + "en": mmdbtype.String("Europe"), + "es": mmdbtype.String("Europa"), + }, + }, + "country": mmdbtype.Map{ + "iso_code": mmdbtype.String("GB"), + "geoname_id": mmdbtype.Uint64(2635167), + "names": mmdbtype.Map{ + "en": mmdbtype.String("United Kingdom"), + "es": mmdbtype.String("Reino Unido"), + }, + "is_in_european_union": mmdbtype.Bool(true), + }, + "location": mmdbtype.Map{ + "accuracy_radius": mmdbtype.Uint16(200), + "latitude": mmdbtype.Float64(52.2242), + "longitude": mmdbtype.Float64(0.1315), + "metro_code": mmdbtype.Uint64(0), + "time_zone": mmdbtype.String("Europe/London"), + }, + "postal": mmdbtype.Map{ + "code": mmdbtype.String("CB4"), + }, + "registered_country": mmdbtype.Map{ + "iso_code": mmdbtype.String("GB"), + "geoname_id": mmdbtype.Uint64(2635167), + "names": mmdbtype.Map{"en": mmdbtype.String("United Kingdom")}, + "is_in_european_union": mmdbtype.Bool(false), + }, + "subdivisions": mmdbtype.Slice{ + mmdbtype.Map{ + "iso_code": mmdbtype.String("ENG"), + "geoname_id": mmdbtype.Uint64(6269131), + "names": mmdbtype.Map{"en": mmdbtype.String("England")}, + }, + mmdbtype.Map{ + "iso_code": mmdbtype.String("CAM"), + "geoname_id": mmdbtype.Uint64(2653940), + "names": mmdbtype.Map{"en": mmdbtype.String("Cambridgeshire")}, + }, + }, + } + + if err := writer.InsertFunc(ip, inserter.TopLevelMergeWith(record)); err != nil { + log.Fatal(err) + } + + // Write the DB to the filesystem. + fh, err := os.Create(dbName) + if err != nil { + log.Fatal(err) + } + _, err = writer.WriteTo(fh) + if err != nil { + log.Fatal(err) + } +} +``` diff --git a/plugin/grpc/README.md b/plugin/grpc/README.md new file mode 100644 index 0000000..c2e8b34 --- /dev/null +++ b/plugin/grpc/README.md @@ -0,0 +1,143 @@ +# grpc + +## Name + +*grpc* - facilitates proxying DNS messages to upstream resolvers via gRPC protocol. + +## Description + +The *grpc* plugin supports gRPC and TLS. + +This plugin can only be used once per Server Block. + +## Syntax + +In its most basic form: + +~~~ +grpc FROM TO... +~~~ + +* **FROM** is the base domain to match for the request to be proxied. +* **TO...** are the destination endpoints to proxy to. The number of upstreams is + limited to 15. + +Multiple upstreams are randomized (see `policy`) on first use. When a proxy returns an error +the next upstream in the list is tried. + +Extra knobs are available with an expanded syntax: + +~~~ +grpc FROM TO... { + except IGNORED_NAMES... + tls CERT KEY CA + tls_servername NAME + policy random|round_robin|sequential +} +~~~ + +* **FROM** and **TO...** as above. +* **IGNORED_NAMES** in `except` is a space-separated list of domains to exclude from proxying. + Requests that match none of these names will be passed through. +* `tls` **CERT** **KEY** **CA** define the TLS properties for TLS connection. From 0 to 3 arguments can be + provided with the meaning as described below + + * `tls` - no client authentication is used, and the system CAs are used to verify the server certificate + * `tls` **CA** - no client authentication is used, and the file CA is used to verify the server certificate + * `tls` **CERT** **KEY** - client authentication is used with the specified cert/key pair. 
+  The server certificate is verified with the system CAs.
+  * `tls` **CERT** **KEY** **CA** - client authentication is used with the specified cert/key pair.
+  The server certificate is verified using the specified CA file.
+
+* `tls_servername` **NAME** allows you to set a server name in the TLS configuration; for instance 9.9.9.9
+  needs this to be set to `dns.quad9.net`. Multiple upstreams are still allowed in this scenario,
+  but they have to use the same `tls_servername`. E.g. mixing 9.9.9.9 (Quad9) with 1.1.1.1
+  (Cloudflare) will not work.
+* `policy` specifies the policy to use for selecting upstream servers. The default is `random`.
+
+Also note that the TLS config is "global" for the whole grpc proxy: if you need a different
+`tls_servername` for different upstreams, you're out of luck.
+
+## Metrics
+
+If monitoring is enabled (via the *prometheus* plugin) then the following metrics are exported:
+
+* `coredns_grpc_request_duration_seconds{to}` - duration per upstream interaction.
+* `coredns_grpc_requests_total{to}` - query count per upstream.
+* `coredns_grpc_responses_total{to, rcode}` - count of RCODEs per upstream.
+
+## Examples
+
+Proxy all requests within `example.org.` to a nameserver running on a different port:
+
+~~~ corefile
+example.org {
+    grpc . 127.0.0.1:9005
+}
+~~~
+
+Load balance all requests between three resolvers, one of which has an IPv6 address:
+
+~~~ corefile
+. {
+    grpc . 10.0.0.10:53 10.0.0.11:1053 [2003::1]:53
+}
+~~~
+
+Forward everything except requests to `example.org`:
+
+~~~ corefile
+. {
+    grpc . 10.0.0.10:1234 {
+        except example.org
+    }
+}
+~~~
+
+Proxy everything except `example.org` using the host's `resolv.conf` nameservers:
+
+~~~ corefile
+. {
+    grpc . /etc/resolv.conf {
+        except example.org
+    }
+}
+~~~
+
+Proxy all requests to 9.9.9.9 using the TLS protocol, and cache every answer for up to 30
+seconds. Note that `tls_servername` is mandatory for a working setup, as 9.9.9.9 can't be
+used in the TLS negotiation:
+
+~~~ corefile
+. {
+    grpc . 9.9.9.9 {
+        tls_servername dns.quad9.net
+    }
+    cache 30
+}
+~~~
+
+Or with multiple upstreams from the same provider:
+
+~~~ corefile
+. {
+    grpc . 1.1.1.1 1.0.0.1 {
+        tls_servername cloudflare-dns.com
+    }
+    cache 30
+}
+~~~
+
+Forward requests to a local upstream listening on a Unix domain socket:
+
+~~~ corefile
+. {
+    grpc . unix:///path/to/grpc.sock
+}
+~~~
+
+## Bugs
+
+The TLS config is global for the whole grpc proxy: if you need a different `tls_servername` for
+different upstreams, you're out of luck.
\ No newline at end of file diff --git a/plugin/grpc/grpc.go b/plugin/grpc/grpc.go new file mode 100644 index 0000000..c2911ed --- /dev/null +++ b/plugin/grpc/grpc.go @@ -0,0 +1,143 @@ +package grpc + +import ( + "context" + "crypto/tls" + "errors" + "time" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/debug" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" + ot "github.com/opentracing/opentracing-go" +) + +// GRPC represents a plugin instance that can proxy requests to another (DNS) server via gRPC protocol. +// It has a list of proxies each representing one upstream proxy. +type GRPC struct { + proxies []*Proxy + p Policy + + from string + ignored []string + + tlsConfig *tls.Config + tlsServerName string + + Next plugin.Handler +} + +// ServeDNS implements the plugin.Handler interface. +func (g *GRPC) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + state := request.Request{W: w, Req: r} + if !g.match(state) { + return plugin.NextOrFailure(g.Name(), g.Next, ctx, w, r) + } + + var ( + span, child ot.Span + ret *dns.Msg + err error + i int + ) + span = ot.SpanFromContext(ctx) + list := g.list() + deadline := time.Now().Add(defaultTimeout) + + for time.Now().Before(deadline) { + if i >= len(list) { + // reached the end of list without any answer + if ret != nil { + // write empty response and finish + w.WriteMsg(ret) + } + break + } + + proxy := list[i] + i++ + + if span != nil { + child = span.Tracer().StartSpan("query", ot.ChildOf(span.Context())) + ctx = ot.ContextWithSpan(ctx, child) + } + + ret, err = proxy.query(ctx, r) + if err != nil { + // Continue with the next proxy + continue + } + + if child != nil { + child.Finish() + } + + // Check if the reply is correct; if not return FormErr. + if !state.Match(ret) { + debug.Hexdumpf(ret, "Wrong reply for id: %d, %s %d", ret.Id, state.QName(), state.QType()) + + formerr := new(dns.Msg) + formerr.SetRcode(state.Req, dns.RcodeFormatError) + w.WriteMsg(formerr) + return 0, nil + } + + w.WriteMsg(ret) + return 0, nil + } + + // SERVFAIL if all healthy proxys returned errors. + if err != nil { + // just return the last error received + return dns.RcodeServerFailure, err + } + + return dns.RcodeServerFailure, ErrNoHealthy +} + +// NewGRPC returns a new GRPC. +func newGRPC() *GRPC { + g := &GRPC{ + p: new(random), + } + return g +} + +// Name implements the Handler interface. +func (g *GRPC) Name() string { return "grpc" } + +// Len returns the number of configured proxies. +func (g *GRPC) len() int { return len(g.proxies) } + +func (g *GRPC) match(state request.Request) bool { + if !plugin.Name(g.from).Matches(state.Name()) || !g.isAllowedDomain(state.Name()) { + return false + } + + return true +} + +func (g *GRPC) isAllowedDomain(name string) bool { + if dns.Name(name) == dns.Name(g.from) { + return true + } + + for _, ignore := range g.ignored { + if plugin.Name(ignore).Matches(name) { + return false + } + } + return true +} + +// List returns a set of proxies to be used for this client depending on the policy in p. +func (g *GRPC) list() []*Proxy { return g.p.List(g.proxies) } + +const defaultTimeout = 5 * time.Second + +var ( + // ErrNoHealthy means no healthy proxies left. 
+ ErrNoHealthy = errors.New("no healthy gRPC proxies") +) diff --git a/plugin/grpc/grpc_test.go b/plugin/grpc/grpc_test.go new file mode 100644 index 0000000..06375ec --- /dev/null +++ b/plugin/grpc/grpc_test.go @@ -0,0 +1,75 @@ +package grpc + +import ( + "context" + "errors" + "testing" + + "github.com/coredns/coredns/pb" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestGRPC(t *testing.T) { + m := &dns.Msg{} + msg, err := m.Pack() + if err != nil { + t.Fatalf("Error packing response: %s", err.Error()) + } + dnsPacket := &pb.DnsPacket{Msg: msg} + tests := map[string]struct { + proxies []*Proxy + wantErr bool + }{ + "single_proxy_ok": { + proxies: []*Proxy{ + {client: &testServiceClient{dnsPacket: dnsPacket, err: nil}}, + }, + wantErr: false, + }, + "multiple_proxies_ok": { + proxies: []*Proxy{ + {client: &testServiceClient{dnsPacket: dnsPacket, err: nil}}, + {client: &testServiceClient{dnsPacket: dnsPacket, err: nil}}, + {client: &testServiceClient{dnsPacket: dnsPacket, err: nil}}, + }, + wantErr: false, + }, + "single_proxy_ko": { + proxies: []*Proxy{ + {client: &testServiceClient{dnsPacket: nil, err: errors.New("")}}, + }, + wantErr: true, + }, + "multiple_proxies_one_ko": { + proxies: []*Proxy{ + {client: &testServiceClient{dnsPacket: dnsPacket, err: nil}}, + {client: &testServiceClient{dnsPacket: nil, err: errors.New("")}}, + {client: &testServiceClient{dnsPacket: dnsPacket, err: nil}}, + }, + wantErr: false, + }, + "multiple_proxies_ko": { + proxies: []*Proxy{ + {client: &testServiceClient{dnsPacket: nil, err: errors.New("")}}, + {client: &testServiceClient{dnsPacket: nil, err: errors.New("")}}, + {client: &testServiceClient{dnsPacket: nil, err: errors.New("")}}, + }, + wantErr: true, + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + g := newGRPC() + g.from = "." + g.proxies = tt.proxies + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + if _, err := g.ServeDNS(context.TODO(), rec, m); err != nil && !tt.wantErr { + t.Fatal("Expected to receive reply, but didn't") + } + }) + } +} diff --git a/plugin/grpc/metrics.go b/plugin/grpc/metrics.go new file mode 100644 index 0000000..2857042 --- /dev/null +++ b/plugin/grpc/metrics.go @@ -0,0 +1,31 @@ +package grpc + +import ( + "github.com/coredns/coredns/plugin" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +// Variables declared for monitoring. 
+var ( + RequestCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "grpc", + Name: "requests_total", + Help: "Counter of requests made per upstream.", + }, []string{"to"}) + RcodeCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "grpc", + Name: "responses_total", + Help: "Counter of requests made per upstream.", + }, []string{"rcode", "to"}) + RequestDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: plugin.Namespace, + Subsystem: "grpc", + Name: "request_duration_seconds", + Buckets: plugin.TimeBuckets, + Help: "Histogram of the time each request took.", + }, []string{"to"}) +) diff --git a/plugin/grpc/policy.go b/plugin/grpc/policy.go new file mode 100644 index 0000000..686b2eb --- /dev/null +++ b/plugin/grpc/policy.go @@ -0,0 +1,68 @@ +package grpc + +import ( + "sync/atomic" + "time" + + "github.com/coredns/coredns/plugin/pkg/rand" +) + +// Policy defines a policy we use for selecting upstreams. +type Policy interface { + List([]*Proxy) []*Proxy + String() string +} + +// random is a policy that implements random upstream selection. +type random struct{} + +func (r *random) String() string { return "random" } + +func (r *random) List(p []*Proxy) []*Proxy { + switch len(p) { + case 1: + return p + case 2: + if rn.Int()%2 == 0 { + return []*Proxy{p[1], p[0]} // swap + } + return p + } + + perms := rn.Perm(len(p)) + rnd := make([]*Proxy, len(p)) + + for i, p1 := range perms { + rnd[i] = p[p1] + } + return rnd +} + +// roundRobin is a policy that selects hosts based on round robin ordering. +type roundRobin struct { + robin uint32 +} + +func (r *roundRobin) String() string { return "round_robin" } + +func (r *roundRobin) List(p []*Proxy) []*Proxy { + poolLen := uint32(len(p)) + i := atomic.AddUint32(&r.robin, 1) % poolLen + + robin := []*Proxy{p[i]} + robin = append(robin, p[:i]...) + robin = append(robin, p[i+1:]...) + + return robin +} + +// sequential is a policy that selects hosts based on sequential ordering. +type sequential struct{} + +func (r *sequential) String() string { return "sequential" } + +func (r *sequential) List(p []*Proxy) []*Proxy { + return p +} + +var rn = rand.New(time.Now().UnixNano()) diff --git a/plugin/grpc/proxy.go b/plugin/grpc/proxy.go new file mode 100644 index 0000000..9a96e95 --- /dev/null +++ b/plugin/grpc/proxy.go @@ -0,0 +1,82 @@ +package grpc + +import ( + "context" + "crypto/tls" + "strconv" + "time" + + "github.com/coredns/coredns/pb" + + "github.com/miekg/dns" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/status" +) + +// Proxy defines an upstream host. +type Proxy struct { + addr string + + // connection + client pb.DnsServiceClient + dialOpts []grpc.DialOption +} + +// newProxy returns a new proxy. +func newProxy(addr string, tlsConfig *tls.Config) (*Proxy, error) { + p := &Proxy{ + addr: addr, + } + + if tlsConfig != nil { + p.dialOpts = append(p.dialOpts, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) + } else { + p.dialOpts = append(p.dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials())) + } + + conn, err := grpc.Dial(p.addr, p.dialOpts...) + if err != nil { + return nil, err + } + p.client = pb.NewDnsServiceClient(conn) + + return p, nil +} + +// query sends the request and waits for a response. 
+func (p *Proxy) query(ctx context.Context, req *dns.Msg) (*dns.Msg, error) { + start := time.Now() + + msg, err := req.Pack() + if err != nil { + return nil, err + } + + reply, err := p.client.Query(ctx, &pb.DnsPacket{Msg: msg}) + if err != nil { + // if not found message, return empty message with NXDomain code + if status.Code(err) == codes.NotFound { + m := new(dns.Msg).SetRcode(req, dns.RcodeNameError) + return m, nil + } + return nil, err + } + ret := new(dns.Msg) + if err := ret.Unpack(reply.Msg); err != nil { + return nil, err + } + + rc, ok := dns.RcodeToString[ret.Rcode] + if !ok { + rc = strconv.Itoa(ret.Rcode) + } + + RequestCount.WithLabelValues(p.addr).Add(1) + RcodeCount.WithLabelValues(rc, p.addr).Add(1) + RequestDuration.WithLabelValues(p.addr).Observe(time.Since(start).Seconds()) + + return ret, nil +} diff --git a/plugin/grpc/proxy_test.go b/plugin/grpc/proxy_test.go new file mode 100644 index 0000000..2ca0b1b --- /dev/null +++ b/plugin/grpc/proxy_test.go @@ -0,0 +1,120 @@ +package grpc + +import ( + "context" + "errors" + "net" + "path" + "testing" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/pb" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" +) + +func TestProxy(t *testing.T) { + tests := map[string]struct { + p *Proxy + res *dns.Msg + wantErr bool + }{ + "response_ok": { + p: &Proxy{}, + res: &dns.Msg{}, + wantErr: false, + }, + "nil_response": { + p: &Proxy{}, + res: nil, + wantErr: true, + }, + "tls": { + p: &Proxy{dialOpts: []grpc.DialOption{grpc.WithTransportCredentials(credentials.NewTLS(nil))}}, + res: &dns.Msg{}, + wantErr: false, + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + var mock *testServiceClient + if tt.res != nil { + msg, err := tt.res.Pack() + if err != nil { + t.Fatalf("Error packing response: %s", err.Error()) + } + mock = &testServiceClient{&pb.DnsPacket{Msg: msg}, nil} + } else { + mock = &testServiceClient{nil, errors.New("server error")} + } + tt.p.client = mock + + _, err := tt.p.query(context.TODO(), new(dns.Msg)) + if err != nil && !tt.wantErr { + t.Fatalf("Error query(): %s", err.Error()) + } + }) + } +} + +type testServiceClient struct { + dnsPacket *pb.DnsPacket + err error +} + +func (m testServiceClient) Query(ctx context.Context, in *pb.DnsPacket, opts ...grpc.CallOption) (*pb.DnsPacket, error) { + return m.dnsPacket, m.err +} + +func TestProxyUnix(t *testing.T) { + tdir := t.TempDir() + + fd := path.Join(tdir, "test.grpc") + listener, err := net.Listen("unix", fd) + if err != nil { + t.Fatal("Failed to listen: ", err) + } + defer listener.Close() + + server := grpc.NewServer() + pb.RegisterDnsServiceServer(server, &grpcDnsServiceServer{}) + + go server.Serve(listener) + defer server.Stop() + + c := caddy.NewTestController("dns", "grpc . unix://"+fd) + g, err := parseGRPC(c) + + if err != nil { + t.Errorf("Failed to create forwarder: %s", err) + } + + m := new(dns.Msg) + m.SetQuestion("example.org.", dns.TypeA) + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + + if _, err := g.ServeDNS(context.TODO(), rec, m); err != nil { + t.Fatal("Expected to receive reply, but didn't") + } + if x := rec.Msg.Answer[0].Header().Name; x != "example.org." 
{ + t.Errorf("Expected %s, got %s", "example.org.", x) + } +} + +type grpcDnsServiceServer struct { + pb.UnimplementedDnsServiceServer +} + +func (*grpcDnsServiceServer) Query(ctx context.Context, in *pb.DnsPacket) (*pb.DnsPacket, error) { + msg := &dns.Msg{} + msg.Unpack(in.GetMsg()) + answer := new(dns.Msg) + answer.Answer = append(answer.Answer, test.A("example.org. IN A 127.0.0.1")) + answer.SetRcode(msg, dns.RcodeSuccess) + buf, _ := answer.Pack() + return &pb.DnsPacket{Msg: buf}, nil +} diff --git a/plugin/grpc/setup.go b/plugin/grpc/setup.go new file mode 100644 index 0000000..d1c6762 --- /dev/null +++ b/plugin/grpc/setup.go @@ -0,0 +1,153 @@ +package grpc + +import ( + "crypto/tls" + "fmt" + "path/filepath" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/parse" + pkgtls "github.com/coredns/coredns/plugin/pkg/tls" +) + +func init() { plugin.Register("grpc", setup) } + +func setup(c *caddy.Controller) error { + g, err := parseGRPC(c) + if err != nil { + return plugin.Error("grpc", err) + } + + if g.len() > max { + return plugin.Error("grpc", fmt.Errorf("more than %d TOs configured: %d", max, g.len())) + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + g.Next = next // Set the Next field, so the plugin chaining works. + return g + }) + + return nil +} + +func parseGRPC(c *caddy.Controller) (*GRPC, error) { + var ( + g *GRPC + err error + i int + ) + for c.Next() { + if i > 0 { + return nil, plugin.ErrOnce + } + i++ + g, err = parseStanza(c) + if err != nil { + return nil, err + } + } + return g, nil +} + +func parseStanza(c *caddy.Controller) (*GRPC, error) { + g := newGRPC() + + if !c.Args(&g.from) { + return g, c.ArgErr() + } + normalized := plugin.Host(g.from).NormalizeExact() + if len(normalized) == 0 { + return g, fmt.Errorf("unable to normalize '%s'", g.from) + } + g.from = normalized[0] // only the first is used. + + to := c.RemainingArgs() + if len(to) == 0 { + return g, c.ArgErr() + } + + toHosts, err := parse.HostPortOrFile(to...) + if err != nil { + return g, err + } + + for c.NextBlock() { + if err := parseBlock(c, g); err != nil { + return g, err + } + } + + if g.tlsServerName != "" { + if g.tlsConfig == nil { + g.tlsConfig = new(tls.Config) + } + g.tlsConfig.ServerName = g.tlsServerName + } + for _, host := range toHosts { + pr, err := newProxy(host, g.tlsConfig) + if err != nil { + return nil, err + } + g.proxies = append(g.proxies, pr) + } + + return g, nil +} + +func parseBlock(c *caddy.Controller, g *GRPC) error { + switch c.Val() { + case "except": + ignore := c.RemainingArgs() + if len(ignore) == 0 { + return c.ArgErr() + } + for i := 0; i < len(ignore); i++ { + g.ignored = append(g.ignored, plugin.Host(ignore[i]).NormalizeExact()...) + } + case "tls": + args := c.RemainingArgs() + if len(args) > 3 { + return c.ArgErr() + } + + for i := range args { + if !filepath.IsAbs(args[i]) && dnsserver.GetConfig(c).Root != "" { + args[i] = filepath.Join(dnsserver.GetConfig(c).Root, args[i]) + } + } + tlsConfig, err := pkgtls.NewTLSConfigFromArgs(args...) 
+ if err != nil { + return err + } + g.tlsConfig = tlsConfig + case "tls_servername": + if !c.NextArg() { + return c.ArgErr() + } + g.tlsServerName = c.Val() + case "policy": + if !c.NextArg() { + return c.ArgErr() + } + switch x := c.Val(); x { + case "random": + g.p = &random{} + case "round_robin": + g.p = &roundRobin{} + case "sequential": + g.p = &sequential{} + default: + return c.Errf("unknown policy '%s'", x) + } + default: + if c.Val() != "}" { + return c.Errf("unknown property '%s'", c.Val()) + } + } + + return nil +} + +const max = 15 // Maximum number of upstreams. diff --git a/plugin/grpc/setup_policy_test.go b/plugin/grpc/setup_policy_test.go new file mode 100644 index 0000000..c13339d --- /dev/null +++ b/plugin/grpc/setup_policy_test.go @@ -0,0 +1,47 @@ +package grpc + +import ( + "strings" + "testing" + + "github.com/coredns/caddy" +) + +func TestSetupPolicy(t *testing.T) { + tests := []struct { + input string + shouldErr bool + expectedPolicy string + expectedErr string + }{ + // positive + {"grpc . 127.0.0.1 {\npolicy random\n}\n", false, "random", ""}, + {"grpc . 127.0.0.1 {\npolicy round_robin\n}\n", false, "round_robin", ""}, + {"grpc . 127.0.0.1 {\npolicy sequential\n}\n", false, "sequential", ""}, + // negative + {"grpc . 127.0.0.1 {\npolicy random2\n}\n", true, "random", "unknown policy"}, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + g, err := parseGRPC(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: expected error but found %s for input %s", i, err, test.input) + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err) + } + + if !strings.Contains(err.Error(), test.expectedErr) { + t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input) + } + } + + if !test.shouldErr && g.p.String() != test.expectedPolicy { + t.Errorf("Test %d: expected: %s, got: %s", i, test.expectedPolicy, g.p.String()) + } + } +} diff --git a/plugin/grpc/setup_test.go b/plugin/grpc/setup_test.go new file mode 100644 index 0000000..f142099 --- /dev/null +++ b/plugin/grpc/setup_test.go @@ -0,0 +1,154 @@ +package grpc + +import ( + "os" + "reflect" + "strings" + "testing" + + "github.com/coredns/caddy" +) + +func TestSetup(t *testing.T) { + tests := []struct { + input string + shouldErr bool + expectedFrom string + expectedIgnored []string + expectedErr string + }{ + // positive + {"grpc . 127.0.0.1", false, ".", nil, ""}, + {"grpc . 127.0.0.1 {\nexcept miek.nl\n}\n", false, ".", nil, ""}, + {"grpc . 127.0.0.1", false, ".", nil, ""}, + {"grpc . 127.0.0.1:53", false, ".", nil, ""}, + {"grpc . 127.0.0.1:8080", false, ".", nil, ""}, + {"grpc . [::1]:53", false, ".", nil, ""}, + {"grpc . [2003::1]:53", false, ".", nil, ""}, + {"grpc . unix:///var/run/g.sock", false, ".", nil, ""}, + // negative + {"grpc . a27.0.0.1", true, "", nil, "not an IP"}, + {"grpc . 127.0.0.1 {\nblaatl\n}\n", true, "", nil, "unknown property"}, + {`grpc . 
::1 + grpc com ::2`, true, "", nil, "plugin"}, + {"grpc xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx 127.0.0.1", true, "", nil, "unable to normalize 'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'"}, + } + + for i, test := range tests { + c := caddy.NewTestController("grpc", test.input) + g, err := parseGRPC(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: expected error but found %s for input %s", i, err, test.input) + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err) + } + + if !strings.Contains(err.Error(), test.expectedErr) { + t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input) + } + } + + if !test.shouldErr && g.from != test.expectedFrom { + t.Errorf("Test %d: expected: %s, got: %s", i, test.expectedFrom, g.from) + } + if !test.shouldErr && test.expectedIgnored != nil { + if !reflect.DeepEqual(g.ignored, test.expectedIgnored) { + t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedIgnored, g.ignored) + } + } + } +} + +func TestSetupTLS(t *testing.T) { + tests := []struct { + input string + shouldErr bool + expectedServerName string + expectedErr string + }{ + // positive + {`grpc . 127.0.0.1 { +tls_servername dns +}`, false, "dns", ""}, + {`grpc . 127.0.0.1 { +tls +}`, false, "", ""}, + {`grpc . 127.0.0.1`, false, "", ""}, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + g, err := parseGRPC(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: expected error but found %s for input %s", i, err, test.input) + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err) + } + + if !strings.Contains(err.Error(), test.expectedErr) { + t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input) + } + } + + if !test.shouldErr && test.expectedServerName != "" && g.tlsConfig != nil && test.expectedServerName != g.tlsConfig.ServerName { + t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedServerName, g.tlsConfig.ServerName) + } + } +} + +func TestSetupResolvconf(t *testing.T) { + const resolv = "resolv.conf" + if err := os.WriteFile(resolv, + []byte(`nameserver 10.10.255.252 +nameserver 10.10.255.253`), 0666); err != nil { + t.Fatalf("Failed to write resolv.conf file: %s", err) + } + defer os.Remove(resolv) + + tests := []struct { + input string + shouldErr bool + expectedErr string + expectedNames []string + }{ + // pass + {`grpc . 
` + resolv, false, "", []string{"10.10.255.252:53", "10.10.255.253:53"}}, + } + + for i, test := range tests { + c := caddy.NewTestController("grpc", test.input) + f, err := parseGRPC(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: expected error but found %s for input %s", i, err, test.input) + continue + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err) + } + + if !strings.Contains(err.Error(), test.expectedErr) { + t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input) + } + } + + if !test.shouldErr { + for j, n := range test.expectedNames { + addr := f.proxies[j].addr + if n != addr { + t.Errorf("Test %d, expected %q, got %q", j, n, addr) + } + } + } + } +} diff --git a/plugin/header/README.md b/plugin/header/README.md new file mode 100644 index 0000000..9a855c7 --- /dev/null +++ b/plugin/header/README.md @@ -0,0 +1,63 @@ +# header + +## Name + +*header* - modifies the header for queries and responses. + +## Description + +*header* ensures that the flags are in the desired state for queries and responses. +The modifications are made transparently for the client and subsequent plugins. + +## Syntax + +~~~ +header { + [SELECTOR] ACTION FLAGS... + [SELECTOR] ACTION FLAGS... +} +~~~ + +* **SELECTOR** defines if the action should be applied on `query` or `response`. In future CoreDNS version the selector will be mandatory. For backwards compatibility the action will be applied on `response` if the selector is undefined. + +* **ACTION** defines the state for DNS message header flags. Actions are evaluated in the order they are defined so last one has the + most precedence. Allowed values are: + * `set` + * `clear` +* **FLAGS** are the DNS header flags that will be modified. Current supported flags include: + * `aa` - Authoritative(Answer) + * `ra` - RecursionAvailable + * `rd` - RecursionDesired + +## Examples + +Make sure recursive available `ra` flag is set in all the responses: + +~~~ corefile +. { + header { + response set ra + } +} +~~~ + +Make sure "recursion available" `ra` and "authoritative answer" `aa` flags are set and "recursion desired" is cleared in all responses: + +~~~ corefile +. { + header { + response set ra aa + response clear rd + } +} +~~~ + +Make sure "recursion desired" `rd` is set for all subsequent plugins:: + +~~~ corefile +. { + header { + query set rd + } +} +~~~ diff --git a/plugin/header/handler.go b/plugin/header/handler.go new file mode 100644 index 0000000..e11eb03 --- /dev/null +++ b/plugin/header/handler.go @@ -0,0 +1,27 @@ +package header + +import ( + "context" + + "github.com/coredns/coredns/plugin" + + "github.com/miekg/dns" +) + +// Header modifies flags of dns.MsgHdr in queries and / or responses +type Header struct { + QueryRules []Rule + ResponseRules []Rule + Next plugin.Handler +} + +// ServeDNS implements the plugin.Handler interface. +func (h Header) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + applyRules(r, h.QueryRules) + + wr := ResponseHeaderWriter{ResponseWriter: w, Rules: h.ResponseRules} + return plugin.NextOrFailure(h.Name(), h.Next, ctx, &wr, r) +} + +// Name implements the plugin.Handler interface. 
+func (h Header) Name() string { return "header" } diff --git a/plugin/header/header.go b/plugin/header/header.go new file mode 100644 index 0000000..830587d --- /dev/null +++ b/plugin/header/header.go @@ -0,0 +1,95 @@ +package header + +import ( + "fmt" + "strings" + + clog "github.com/coredns/coredns/plugin/pkg/log" + + "github.com/miekg/dns" +) + +// Supported flags +const ( + authoritative = "aa" + recursionAvailable = "ra" + recursionDesired = "rd" +) + +var log = clog.NewWithPlugin("header") + +// ResponseHeaderWriter is a response writer that allows modifying dns.MsgHdr +type ResponseHeaderWriter struct { + dns.ResponseWriter + Rules []Rule +} + +// WriteMsg implements the dns.ResponseWriter interface. +func (r *ResponseHeaderWriter) WriteMsg(res *dns.Msg) error { + applyRules(res, r.Rules) + return r.ResponseWriter.WriteMsg(res) +} + +// Write implements the dns.ResponseWriter interface. +func (r *ResponseHeaderWriter) Write(buf []byte) (int, error) { + log.Warning("ResponseHeaderWriter called with Write: not ensuring headers") + n, err := r.ResponseWriter.Write(buf) + return n, err +} + +// Rule is used to set/clear Flag in dns.MsgHdr +type Rule struct { + Flag string + State bool +} + +func newRules(key string, args []string) ([]Rule, error) { + if key == "" { + return nil, fmt.Errorf("no flag action provided") + } + + if len(args) < 1 { + return nil, fmt.Errorf("invalid length for flags, at least one should be provided") + } + + var state bool + action := strings.ToLower(key) + switch action { + case "set": + state = true + case "clear": + state = false + default: + return nil, fmt.Errorf("unknown flag action=%s, should be set or clear", action) + } + + var rules []Rule + for _, arg := range args { + flag := strings.ToLower(arg) + switch flag { + case authoritative: + case recursionAvailable: + case recursionDesired: + default: + return nil, fmt.Errorf("unknown/unsupported flag=%s", flag) + } + rule := Rule{Flag: flag, State: state} + rules = append(rules, rule) + } + + return rules, nil +} + +func applyRules(res *dns.Msg, rules []Rule) { + // handle all supported flags + for _, rule := range rules { + switch rule.Flag { + case authoritative: + res.Authoritative = rule.State + case recursionAvailable: + res.RecursionAvailable = rule.State + case recursionDesired: + res.RecursionDesired = rule.State + } + } +} diff --git a/plugin/header/header_test.go b/plugin/header/header_test.go new file mode 100644 index 0000000..1182654 --- /dev/null +++ b/plugin/header/header_test.go @@ -0,0 +1,152 @@ +package header + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestHeaderResponseRules(t *testing.T) { + wr := dnstest.NewRecorder(&test.ResponseWriter{}) + next := plugin.HandlerFunc(func(ctx context.Context, writer dns.ResponseWriter, msg *dns.Msg) (int, error) { + writer.WriteMsg(msg) + return dns.RcodeSuccess, nil + }) + + tests := []struct { + handler plugin.Handler + got func(msg *dns.Msg) bool + expected bool + }{ + { + handler: Header{ + ResponseRules: []Rule{{Flag: recursionAvailable, State: true}}, + Next: next, + }, + got: func(msg *dns.Msg) bool { + return msg.RecursionAvailable + }, + expected: true, + }, + { + handler: Header{ + ResponseRules: []Rule{{Flag: recursionAvailable, State: false}}, + Next: next, + }, + got: func(msg *dns.Msg) bool { + return msg.RecursionAvailable + }, + expected: false, + }, + { + handler: 
Header{ + ResponseRules: []Rule{{Flag: recursionDesired, State: true}}, + Next: next, + }, + got: func(msg *dns.Msg) bool { + return msg.RecursionDesired + }, + expected: true, + }, + { + handler: Header{ + ResponseRules: []Rule{{Flag: authoritative, State: true}}, + Next: next, + }, + got: func(msg *dns.Msg) bool { + return msg.Authoritative + }, + expected: true, + }, + } + + for i, test := range tests { + m := new(dns.Msg) + + _, err := test.handler.ServeDNS(context.TODO(), wr, m) + if err != nil { + t.Errorf("Test %d: Expected no error, but got %s", i, err) + continue + } + + if test.got(m) != test.expected { + t.Errorf("Test %d: Expected flag state=%t, but got %t", i, test.expected, test.got(m)) + continue + } + } +} + +func TestHeaderQueryRules(t *testing.T) { + wr := dnstest.NewRecorder(&test.ResponseWriter{}) + next := plugin.HandlerFunc(func(ctx context.Context, writer dns.ResponseWriter, msg *dns.Msg) (int, error) { + writer.WriteMsg(msg) + return dns.RcodeSuccess, nil + }) + + tests := []struct { + handler plugin.Handler + got func(msg *dns.Msg) bool + expected bool + }{ + { + handler: Header{ + QueryRules: []Rule{{Flag: recursionAvailable, State: true}}, + Next: next, + }, + got: func(msg *dns.Msg) bool { + return msg.RecursionAvailable + }, + expected: true, + }, + { + handler: Header{ + QueryRules: []Rule{{Flag: recursionDesired, State: true}}, + Next: next, + }, + got: func(msg *dns.Msg) bool { + return msg.RecursionDesired + }, + expected: true, + }, + { + handler: Header{ + QueryRules: []Rule{{Flag: recursionDesired, State: false}}, + Next: next, + }, + got: func(msg *dns.Msg) bool { + return msg.RecursionDesired + }, + expected: false, + }, + { + handler: Header{ + QueryRules: []Rule{{Flag: authoritative, State: true}}, + Next: next, + }, + got: func(msg *dns.Msg) bool { + return msg.Authoritative + }, + expected: true, + }, + } + + for i, tc := range tests { + m := new(dns.Msg) + + _, err := tc.handler.ServeDNS(context.TODO(), wr, m) + if err != nil { + t.Errorf("Test %d: Expected no error, but got %s", i, err) + continue + } + + if tc.got(m) != tc.expected { + t.Errorf("Test %d: Expected flag state=%t, but got %t", i, tc.expected, tc.got(m)) + continue + } + } +} diff --git a/plugin/header/setup.go b/plugin/header/setup.go new file mode 100644 index 0000000..3d6facf --- /dev/null +++ b/plugin/header/setup.go @@ -0,0 +1,74 @@ +package header + +import ( + "fmt" + "strings" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" +) + +func init() { plugin.Register("header", setup) } + +func setup(c *caddy.Controller) error { + queryRules, responseRules, err := parse(c) + if err != nil { + return plugin.Error("header", err) + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + return Header{ + QueryRules: queryRules, + ResponseRules: responseRules, + Next: next, + } + }) + + return nil +} + +func parse(c *caddy.Controller) ([]Rule, []Rule, error) { + for c.Next() { + var queryRules []Rule + var responseRules []Rule + + for c.NextBlock() { + selector := strings.ToLower(c.Val()) + + var action string + if selector == "set" || selector == "clear" { + log.Warningf("The selector for header rule in line %d isn't explicit defined. "+ + "Assume rule applies for selector 'response'. This syntax is deprecated. 
"+ + "In future versions of CoreDNS the selector must be explicit defined.", + c.Line()) + + action = selector + selector = "response" + } else if selector == "query" || selector == "response" { + if c.NextArg() { + action = c.Val() + } + } else { + return nil, nil, fmt.Errorf("setting up rule: invalid selector=%s should be query or response", selector) + } + + args := c.RemainingArgs() + rules, err := newRules(action, args) + if err != nil { + return nil, nil, fmt.Errorf("setting up rule: %w", err) + } + + if selector == "response" { + responseRules = append(responseRules, rules...) + } else { + queryRules = append(queryRules, rules...) + } + } + + if len(queryRules) > 0 || len(responseRules) > 0 { + return queryRules, responseRules, nil + } + } + return nil, nil, c.ArgErr() +} diff --git a/plugin/header/setup_test.go b/plugin/header/setup_test.go new file mode 100644 index 0000000..36b7995 --- /dev/null +++ b/plugin/header/setup_test.go @@ -0,0 +1,65 @@ +package header + +import ( + "strings" + "testing" + + "github.com/coredns/caddy" +) + +func TestSetupHeader(t *testing.T) { + tests := []struct { + input string + shouldErr bool + expectedErrContent string + }{ + {`header {}`, true, "Wrong argument count or unexpected line ending after"}, + {`header { + set +}`, true, "invalid length for flags, at least one should be provided"}, + {`header { + foo +}`, true, "invalid selector=foo should be query or response"}, + {`header { + query foo +}`, true, "invalid length for flags, at least one should be provided"}, + {`header { + query foo rd +}`, true, "unknown flag action=foo, should be set or clear"}, + {`header { + set ra +}`, false, ""}, + {`header { + clear ra + }`, false, ""}, + {`header { + query set rd + }`, false, ""}, + {`header { + response set aa + }`, false, ""}, + {`header { + set ra aa + clear rd +}`, false, ""}, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + err := setup(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected error but found none for input %s", i, test.input) + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: Expected no error but found one for input %s. Error was: %v", i, test.input, err) + } + + if !strings.Contains(err.Error(), test.expectedErrContent) { + t.Errorf("Test %d: Expected error to contain: %v, found error: %v, input: %s", i, test.expectedErrContent, err, test.input) + } + } + } +} diff --git a/plugin/health/README.md b/plugin/health/README.md new file mode 100644 index 0000000..b18d2ec --- /dev/null +++ b/plugin/health/README.md @@ -0,0 +1,80 @@ +# health + +## Name + +*health* - enables a health check endpoint. + +## Description + +Enabled process wide health endpoint. When CoreDNS is up and running this returns a 200 OK HTTP +status code. The health is exported, by default, on port 8080/health. + +## Syntax + +~~~ +health [ADDRESS] +~~~ + +Optionally takes an address; the default is `:8080`. The health path is fixed to `/health`. The +health endpoint returns a 200 response code and the word "OK" when this server is healthy. + +An extra option can be set with this extended syntax: + +~~~ +health [ADDRESS] { + lameduck DURATION +} +~~~ + +* Where `lameduck` will delay shutdown for **DURATION**. /health will still answer 200 OK. + Note: The *ready* plugin will not answer OK while CoreDNS is in lame duck mode prior to shutdown. + +If you have multiple Server Blocks, *health* can only be enabled in one of them (as it is process +wide). 
If you really need multiple endpoints, you must run health endpoints on different ports: + +~~~ corefile +com { + whoami + health :8080 +} + +net { + erratic + health :8081 +} +~~~ + +Doing this is supported but both endpoints ":8080" and ":8081" will export the exact same health. + +## Metrics + +If monitoring is enabled (via the *prometheus* plugin) then the following metrics are exported: + + * `coredns_health_request_duration_seconds{}` - The *health* plugin performs a self health check + once per second on the `/health` endpoint. This metric is the duration to process that request. + As this is a local operation it should be fast. A (large) increase in this + duration indicates the CoreDNS process is having trouble keeping up with its query load. + * `coredns_health_request_failures_total{}` - The number of times the self health check failed. + +Note that these metrics *do not* have a `server` label, because being overloaded is a symptom of +the running process, *not* a specific server. + +## Examples + +Run another health endpoint on http://localhost:8091. + +~~~ corefile +. { + health localhost:8091 +} +~~~ + +Set a lame duck duration of 1 second: + +~~~ corefile +. { + health localhost:8092 { + lameduck 1s + } +} +~~~ diff --git a/plugin/health/health.go b/plugin/health/health.go new file mode 100644 index 0000000..980cf2b --- /dev/null +++ b/plugin/health/health.go @@ -0,0 +1,99 @@ +// Package health implements an HTTP handler that responds to health checks. +package health + +import ( + "context" + "io" + "net" + "net/http" + "net/url" + "time" + + clog "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/plugin/pkg/reuseport" +) + +var log = clog.NewWithPlugin("health") + +// Health implements healthchecks by exporting a HTTP endpoint. +type health struct { + Addr string + lameduck time.Duration + healthURI *url.URL + + ln net.Listener + nlSetup bool + mux *http.ServeMux + + stop context.CancelFunc +} + +func (h *health) OnStartup() error { + if h.Addr == "" { + h.Addr = ":8080" + } + + var err error + h.healthURI, err = url.Parse("http://" + h.Addr) + if err != nil { + return err + } + + h.healthURI.Path = "/health" + if h.healthURI.Host == "" { + // while we can listen on multiple network interfaces, we need to pick one to poll + h.healthURI.Host = "localhost" + } + + ln, err := reuseport.Listen("tcp", h.Addr) + if err != nil { + return err + } + + h.ln = ln + h.mux = http.NewServeMux() + h.nlSetup = true + + h.mux.HandleFunc(h.healthURI.Path, func(w http.ResponseWriter, r *http.Request) { + // We're always healthy. 
+ w.WriteHeader(http.StatusOK) + io.WriteString(w, http.StatusText(http.StatusOK)) + }) + + ctx := context.Background() + ctx, h.stop = context.WithCancel(ctx) + + go func() { http.Serve(h.ln, h.mux) }() + go func() { h.overloaded(ctx) }() + + return nil +} + +func (h *health) OnFinalShutdown() error { + if !h.nlSetup { + return nil + } + + if h.lameduck > 0 { + log.Infof("Going into lameduck mode for %s", h.lameduck) + time.Sleep(h.lameduck) + } + + h.stop() + + h.ln.Close() + h.nlSetup = false + return nil +} + +func (h *health) OnReload() error { + if !h.nlSetup { + return nil + } + + h.stop() + + h.ln.Close() + h.nlSetup = false + return nil +} diff --git a/plugin/health/health_test.go b/plugin/health/health_test.go new file mode 100644 index 0000000..c49a5d7 --- /dev/null +++ b/plugin/health/health_test.go @@ -0,0 +1,47 @@ +package health + +import ( + "fmt" + "io" + "net/http" + "testing" + "time" +) + +func TestHealth(t *testing.T) { + h := &health{Addr: ":0"} + + if err := h.OnStartup(); err != nil { + t.Fatalf("Unable to startup the health server: %v", err) + } + defer h.OnFinalShutdown() + + address := fmt.Sprintf("http://%s%s", h.ln.Addr().String(), "/health") + + response, err := http.Get(address) + if err != nil { + t.Fatalf("Unable to query %s: %v", address, err) + } + if response.StatusCode != http.StatusOK { + t.Errorf("Invalid status code: expecting '200', got '%d'", response.StatusCode) + } + content, err := io.ReadAll(response.Body) + if err != nil { + t.Fatalf("Unable to get response body from %s: %v", address, err) + } + response.Body.Close() + + if string(content) != http.StatusText(http.StatusOK) { + t.Errorf("Invalid response body: expecting 'OK', got '%s'", string(content)) + } +} + +func TestHealthLameduck(t *testing.T) { + h := &health{Addr: ":0", lameduck: 250 * time.Millisecond} + + if err := h.OnStartup(); err != nil { + t.Fatalf("Unable to startup the health server: %v", err) + } + + h.OnFinalShutdown() +} diff --git a/plugin/health/log_test.go b/plugin/health/log_test.go new file mode 100644 index 0000000..7e6c97b --- /dev/null +++ b/plugin/health/log_test.go @@ -0,0 +1,5 @@ +package health + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/health/overloaded.go b/plugin/health/overloaded.go new file mode 100644 index 0000000..160f90f --- /dev/null +++ b/plugin/health/overloaded.go @@ -0,0 +1,83 @@ +package health + +import ( + "context" + "net" + "net/http" + "time" + + "github.com/coredns/coredns/plugin" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +// overloaded queries the health end point and updates a metrics showing how long it took. 
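+// It polls once per second with an HTTP client that bypasses any configured proxy, and feeds the coredns_health_request_duration_seconds and coredns_health_request_failures_total metrics.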
+func (h *health) overloaded(ctx context.Context) { + bypassProxy := &http.Transport{ + Proxy: nil, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + timeout := 3 * time.Second + client := http.Client{ + Timeout: timeout, + Transport: bypassProxy, + } + + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, h.healthURI.String(), nil) + tick := time.NewTicker(1 * time.Second) + defer tick.Stop() + + for { + select { + case <-tick.C: + start := time.Now() + resp, err := client.Do(req) + if err != nil && ctx.Err() == context.Canceled { + // request was cancelled by parent goroutine + return + } + if err != nil { + HealthDuration.Observe(time.Since(start).Seconds()) + HealthFailures.Inc() + log.Warningf("Local health request to %q failed: %s", req.URL.String(), err) + continue + } + resp.Body.Close() + elapsed := time.Since(start) + HealthDuration.Observe(elapsed.Seconds()) + if elapsed > time.Second { // 1s is pretty random, but a *local* scrape taking that long isn't good + log.Warningf("Local health request to %q took more than 1s: %s", req.URL.String(), elapsed) + } + + case <-ctx.Done(): + return + } + } +} + +var ( + // HealthDuration is the metric used for exporting how fast we can retrieve the /health endpoint. + HealthDuration = promauto.NewHistogram(prometheus.HistogramOpts{ + Namespace: plugin.Namespace, + Subsystem: "health", + Name: "request_duration_seconds", + Buckets: plugin.SlimTimeBuckets, + Help: "Histogram of the time (in seconds) each request took.", + }) + // HealthFailures is the metric used to count how many times the health request failed + HealthFailures = promauto.NewCounter(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "health", + Name: "request_failures_total", + Help: "The number of times the health check failed.", + }) +) diff --git a/plugin/health/overloaded_test.go b/plugin/health/overloaded_test.go new file mode 100644 index 0000000..da40a4e --- /dev/null +++ b/plugin/health/overloaded_test.go @@ -0,0 +1,49 @@ +package health + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" +) + +func Test_health_overloaded_cancellation(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(1 * time.Second) + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + h := &health{ + Addr: ts.URL, + stop: cancel, + } + + var err error + h.healthURI, err = url.Parse(ts.URL) + if err != nil { + t.Fatal(err) + } + h.healthURI.Path = "/health" + + stopped := make(chan struct{}) + go func() { + h.overloaded(ctx) + stopped <- struct{}{} + }() + + // wait for overloaded function to start atleast once + time.Sleep(1 * time.Second) + + cancel() + + select { + case <-stopped: + case <-time.After(5 * time.Second): + t.Fatal("overloaded function should have been cancelled") + } +} diff --git a/plugin/health/setup.go b/plugin/health/setup.go new file mode 100644 index 0000000..e9163ad --- /dev/null +++ b/plugin/health/setup.go @@ -0,0 +1,66 @@ +package health + +import ( + "fmt" + "net" + "time" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/plugin" +) + +func init() { plugin.Register("health", setup) } + +func setup(c *caddy.Controller) error { + 
addr, lame, err := parse(c) + if err != nil { + return plugin.Error("health", err) + } + + h := &health{Addr: addr, lameduck: lame} + + c.OnStartup(h.OnStartup) + c.OnRestart(h.OnReload) + c.OnFinalShutdown(h.OnFinalShutdown) + c.OnRestartFailed(h.OnStartup) + + // Don't do AddPlugin, as health is not *really* a plugin just a separate webserver running. + return nil +} + +func parse(c *caddy.Controller) (string, time.Duration, error) { + addr := "" + dur := time.Duration(0) + for c.Next() { + args := c.RemainingArgs() + + switch len(args) { + case 0: + case 1: + addr = args[0] + if _, _, e := net.SplitHostPort(addr); e != nil { + return "", 0, e + } + default: + return "", 0, c.ArgErr() + } + + for c.NextBlock() { + switch c.Val() { + case "lameduck": + args := c.RemainingArgs() + if len(args) != 1 { + return "", 0, c.ArgErr() + } + l, err := time.ParseDuration(args[0]) + if err != nil { + return "", 0, fmt.Errorf("unable to parse lameduck duration value: '%v' : %v", args[0], err) + } + dur = l + default: + return "", 0, c.ArgErr() + } + } + } + return addr, dur, nil +} diff --git a/plugin/health/setup_test.go b/plugin/health/setup_test.go new file mode 100644 index 0000000..7bb2132 --- /dev/null +++ b/plugin/health/setup_test.go @@ -0,0 +1,45 @@ +package health + +import ( + "testing" + + "github.com/coredns/caddy" +) + +func TestSetupHealth(t *testing.T) { + tests := []struct { + input string + shouldErr bool + }{ + {`health`, false}, + {`health localhost:1234`, false}, + {`health localhost:1234 { + lameduck 4s +}`, false}, + {`health bla:a`, false}, + + {`health bla`, true}, + {`health bla bla`, true}, + {`health localhost:1234 { + lameduck a +}`, true}, + {`health localhost:1234 { + lamedudk 4 +} `, true}, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + _, _, err := parse(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected error but found none for input %s", i, test.input) + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: Expected no error but found one for input %s. Error was: %v", i, test.input, err) + } + } + } +} diff --git a/plugin/hosts/README.md b/plugin/hosts/README.md new file mode 100644 index 0000000..dde6259 --- /dev/null +++ b/plugin/hosts/README.md @@ -0,0 +1,125 @@ +# hosts + +## Name + +*hosts* - enables serving zone data from a `/etc/hosts` style file. + +## Description + +The *hosts* plugin is useful for serving zones from a `/etc/hosts` file. It serves from a preloaded +file that exists on disk. It checks the file for changes and updates the zones accordingly. This +plugin only supports A, AAAA, and PTR records. The hosts plugin can be used with readily +available hosts files that block access to advertising servers. + +The plugin reloads the content of the hosts file every 5 seconds. Upon reload, CoreDNS will use the +new definitions. Should the file be deleted, any inlined content will continue to be served. When +the file is restored, it will then again be used. + +If you want to pass the request to the rest of the plugin chain if there is no match in the *hosts* +plugin, you must specify the `fallthrough` option. + +This plugin can only be used once per Server Block. + +## The hosts file + +Commonly the entries are of the form `IP_address canonical_hostname [aliases...]` as explained by +the hosts(5) man page. 
+ +Examples: + +~~~ +# The following lines are desirable for IPv4 capable hosts +127.0.0.1 localhost +192.168.1.10 example.com example + +# The following lines are desirable for IPv6 capable hosts +::1 localhost ip6-localhost ip6-loopback +fdfc:a744:27b5:3b0e::1 example.com example +~~~ + +### PTR records + +PTR records for reverse lookups are generated automatically by CoreDNS (based on the hosts file +entries) and cannot be created manually. + +## Syntax + +~~~ +hosts [FILE [ZONES...]] { + [INLINE] + ttl SECONDS + no_reverse + reload DURATION + fallthrough [ZONES...] +} +~~~ + +* **FILE** the hosts file to read and parse. If the path is relative the path from the *root* + plugin will be prepended to it. Defaults to /etc/hosts if omitted. We scan the file for changes + every 5 seconds. +* **ZONES** zones it should be authoritative for. If empty, the zones from the configuration block + are used. +* **INLINE** the hosts file contents inlined in Corefile. If there are any lines before fallthrough + then all of them will be treated as the additional content for hosts file. The specified hosts + file path will still be read but entries will be overridden. +* `ttl` change the DNS TTL of the records generated (forward and reverse). The default is 3600 seconds (1 hour). +* `reload` change the period between each hostsfile reload. A time of zero seconds disables the + feature. Examples of valid durations: "300ms", "1.5h" or "2h45m". See Go's + [time](https://godoc.org/time). package. +* `no_reverse` disable the automatic generation of the `in-addr.arpa` or `ip6.arpa` entries for the hosts +* `fallthrough` If zone matches and no record can be generated, pass request to the next plugin. + If **[ZONES...]** is omitted, then fallthrough happens for all zones for which the plugin + is authoritative. If specific zones are listed (for example `in-addr.arpa` and `ip6.arpa`), then only + queries for those zones will be subject to fallthrough. + +## Metrics + +If monitoring is enabled (via the *prometheus* plugin) then the following metrics are exported: + +- `coredns_hosts_entries{}` - The combined number of entries in hosts and Corefile. +- `coredns_hosts_reload_timestamp_seconds{}` - The timestamp of the last reload of hosts file. + +## Examples + +Load `/etc/hosts` file. + +~~~ corefile +. { + hosts +} +~~~ + +Load `example.hosts` file in the current directory. + +~~~ +. { + hosts example.hosts +} +~~~ + +Load example.hosts file and only serve example.org and example.net from it and fall through to the +next plugin if query doesn't match. + +~~~ +. { + hosts example.hosts example.org example.net { + fallthrough + } +} +~~~ + +Load hosts file inlined in Corefile. + +~~~ +example.hosts example.org { + hosts { + 10.0.0.1 example.org + fallthrough + } + whoami +} +~~~ + +## See also + +The form of the entries in the `/etc/hosts` file are based on IETF [RFC 952](https://tools.ietf.org/html/rfc952) which was updated by IETF [RFC 1123](https://tools.ietf.org/html/rfc1123). 
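+
+For completeness, a combined sketch of the directives documented under Syntax (the file name `example.hosts` is only illustrative; any hosts-format file works): serve `example.org` from that file with a 60 second TTL, re-read it every 30 seconds, skip the automatic reverse entries, and let unmatched queries fall through to the next plugin (*whoami* here).
+
+~~~ corefile
+example.org {
+    hosts example.hosts example.org {
+        ttl 60
+        reload 30s
+        no_reverse
+        fallthrough
+    }
+    whoami
+}
+~~~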
diff --git a/plugin/hosts/hosts.go b/plugin/hosts/hosts.go new file mode 100644 index 0000000..5c644e7 --- /dev/null +++ b/plugin/hosts/hosts.go @@ -0,0 +1,122 @@ +package hosts + +import ( + "context" + "net" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/dnsutil" + "github.com/coredns/coredns/plugin/pkg/fall" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// Hosts is the plugin handler +type Hosts struct { + Next plugin.Handler + *Hostsfile + + Fall fall.F +} + +// ServeDNS implements the plugin.Handle interface. +func (h Hosts) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + state := request.Request{W: w, Req: r} + qname := state.Name() + + answers := []dns.RR{} + + zone := plugin.Zones(h.Origins).Matches(qname) + if zone == "" { + // PTR zones don't need to be specified in Origins. + if state.QType() != dns.TypePTR { + // if this doesn't match we need to fall through regardless of h.Fallthrough + return plugin.NextOrFailure(h.Name(), h.Next, ctx, w, r) + } + } + + switch state.QType() { + case dns.TypePTR: + names := h.LookupStaticAddr(dnsutil.ExtractAddressFromReverse(qname)) + if len(names) == 0 { + // If this doesn't match we need to fall through regardless of h.Fallthrough + return plugin.NextOrFailure(h.Name(), h.Next, ctx, w, r) + } + answers = h.ptr(qname, h.options.ttl, names) + case dns.TypeA: + ips := h.LookupStaticHostV4(qname) + answers = a(qname, h.options.ttl, ips) + case dns.TypeAAAA: + ips := h.LookupStaticHostV6(qname) + answers = aaaa(qname, h.options.ttl, ips) + } + + // Only on NXDOMAIN we will fallthrough. + if len(answers) == 0 && !h.otherRecordsExist(qname) { + if h.Fall.Through(qname) { + return plugin.NextOrFailure(h.Name(), h.Next, ctx, w, r) + } + + // We want to send an NXDOMAIN, but because of /etc/hosts' setup we don't have a SOA, so we make it SERVFAIL + // to at least give an answer back to signals we're having problems resolving this. + return dns.RcodeServerFailure, nil + } + + m := new(dns.Msg) + m.SetReply(r) + m.Authoritative = true + m.Answer = answers + + w.WriteMsg(m) + return dns.RcodeSuccess, nil +} + +func (h Hosts) otherRecordsExist(qname string) bool { + if len(h.LookupStaticHostV4(qname)) > 0 { + return true + } + if len(h.LookupStaticHostV6(qname)) > 0 { + return true + } + return false +} + +// Name implements the plugin.Handle interface. +func (h Hosts) Name() string { return "hosts" } + +// a takes a slice of net.IPs and returns a slice of A RRs. +func a(zone string, ttl uint32, ips []net.IP) []dns.RR { + answers := make([]dns.RR, len(ips)) + for i, ip := range ips { + r := new(dns.A) + r.Hdr = dns.RR_Header{Name: zone, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: ttl} + r.A = ip + answers[i] = r + } + return answers +} + +// aaaa takes a slice of net.IPs and returns a slice of AAAA RRs. +func aaaa(zone string, ttl uint32, ips []net.IP) []dns.RR { + answers := make([]dns.RR, len(ips)) + for i, ip := range ips { + r := new(dns.AAAA) + r.Hdr = dns.RR_Header{Name: zone, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: ttl} + r.AAAA = ip + answers[i] = r + } + return answers +} + +// ptr takes a slice of host names and filters out the ones that aren't in Origins, if specified, and returns a slice of PTR RRs. 
+func (h *Hosts) ptr(zone string, ttl uint32, names []string) []dns.RR { + answers := make([]dns.RR, len(names)) + for i, n := range names { + r := new(dns.PTR) + r.Hdr = dns.RR_Header{Name: zone, Rrtype: dns.TypePTR, Class: dns.ClassINET, Ttl: ttl} + r.Ptr = dns.Fqdn(n) + answers[i] = r + } + return answers +} diff --git a/plugin/hosts/hosts_test.go b/plugin/hosts/hosts_test.go new file mode 100644 index 0000000..320655a --- /dev/null +++ b/plugin/hosts/hosts_test.go @@ -0,0 +1,120 @@ +package hosts + +import ( + "context" + "strings" + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/pkg/fall" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestLookupA(t *testing.T) { + for _, tc := range hostsTestCases { + m := tc.Msg() + + var tcFall fall.F + isFall := tc.Qname == "fallthrough-example.org." + if isFall { + tcFall = fall.Root + } else { + tcFall = fall.Zero + } + + h := Hosts{ + Next: test.NextHandler(dns.RcodeNameError, nil), + Hostsfile: &Hostsfile{ + Origins: []string{"."}, + hmap: newMap(), + inline: newMap(), + options: newOptions(), + }, + Fall: tcFall, + } + h.hmap = h.parse(strings.NewReader(hostsExample)) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + + rcode, err := h.ServeDNS(context.Background(), rec, m) + if err != nil { + t.Errorf("Expected no error, got %v", err) + return + } + + if isFall && tc.Rcode != rcode { + t.Errorf("Expected rcode is %d, but got %d", tc.Rcode, rcode) + return + } + + if resp := rec.Msg; rec.Msg != nil { + if err := test.SortAndCheck(resp, tc); err != nil { + t.Error(err) + } + } + } +} + +var hostsTestCases = []test.Case{ + { + Qname: "example.org.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.A("example.org. 3600 IN A 10.0.0.1"), + }, + }, + { + Qname: "example.com.", Qtype: dns.TypeA, + Answer: []dns.RR{ + test.A("example.com. 3600 IN A 10.0.0.2"), + }, + }, + { + Qname: "localhost.", Qtype: dns.TypeAAAA, + Answer: []dns.RR{ + test.AAAA("localhost. 3600 IN AAAA ::1"), + }, + }, + { + Qname: "1.0.0.10.in-addr.arpa.", Qtype: dns.TypePTR, + Answer: []dns.RR{ + test.PTR("1.0.0.10.in-addr.arpa. 3600 PTR example.org."), + }, + }, + { + Qname: "2.0.0.10.in-addr.arpa.", Qtype: dns.TypePTR, + Answer: []dns.RR{ + test.PTR("2.0.0.10.in-addr.arpa. 3600 PTR example.com."), + }, + }, + { + Qname: "1.0.0.127.in-addr.arpa.", Qtype: dns.TypePTR, + Answer: []dns.RR{ + test.PTR("1.0.0.127.in-addr.arpa. 3600 PTR localhost."), + test.PTR("1.0.0.127.in-addr.arpa. 3600 PTR localhost.domain."), + }, + }, + { + Qname: "example.org.", Qtype: dns.TypeAAAA, + Answer: []dns.RR{}, + }, + { + Qname: "example.org.", Qtype: dns.TypeMX, + Answer: []dns.RR{}, + }, + { + Qname: "fallthrough-example.org.", Qtype: dns.TypeAAAA, + Answer: []dns.RR{}, Rcode: dns.RcodeSuccess, + }, +} + +const hostsExample = ` +127.0.0.1 localhost localhost.domain +::1 localhost localhost.domain +10.0.0.1 example.org +::FFFF:10.0.0.2 example.com +10.0.0.3 fallthrough-example.org +reload 5s +timeout 3600 +` diff --git a/plugin/hosts/hostsfile.go b/plugin/hosts/hostsfile.go new file mode 100644 index 0000000..e5aff0d --- /dev/null +++ b/plugin/hosts/hostsfile.go @@ -0,0 +1,259 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +// This file is a modified version of net/hosts.go from the golang repo + +package hosts + +import ( + "bufio" + "bytes" + "io" + "net" + "os" + "strings" + "sync" + "time" + + "github.com/coredns/coredns/plugin" +) + +// parseIP calls discards any v6 zone info, before calling net.ParseIP. +func parseIP(addr string) net.IP { + if i := strings.Index(addr, "%"); i >= 0 { + // discard ipv6 zone + addr = addr[0:i] + } + + return net.ParseIP(addr) +} + +type options struct { + // automatically generate IP to Hostname PTR entries + // for host entries we parse + autoReverse bool + + // The TTL of the record we generate + ttl uint32 + + // The time between two reload of the configuration + reload time.Duration +} + +func newOptions() *options { + return &options{ + autoReverse: true, + ttl: 3600, + reload: 5 * time.Second, + } +} + +// Map contains the IPv4/IPv6 and reverse mapping. +type Map struct { + // Key for the list of literal IP addresses must be a FQDN lowercased host name. + name4 map[string][]net.IP + name6 map[string][]net.IP + + // Key for the list of host names must be a literal IP address + // including IPv6 address without zone identifier. + // We don't support old-classful IP address notation. + addr map[string][]string +} + +func newMap() *Map { + return &Map{ + name4: make(map[string][]net.IP), + name6: make(map[string][]net.IP), + addr: make(map[string][]string), + } +} + +// Len returns the total number of addresses in the hostmap, this includes V4/V6 and any reverse addresses. +func (h *Map) Len() int { + l := 0 + for _, v4 := range h.name4 { + l += len(v4) + } + for _, v6 := range h.name6 { + l += len(v6) + } + for _, a := range h.addr { + l += len(a) + } + return l +} + +// Hostsfile contains known host entries. +type Hostsfile struct { + sync.RWMutex + + // list of zones we are authoritative for + Origins []string + + // hosts maps for lookups + hmap *Map + + // inline saves the hosts file that is inlined in a Corefile. + inline *Map + + // path to the hosts file + path string + + // mtime and size are only read and modified by a single goroutine + mtime time.Time + size int64 + + options *options +} + +// readHosts determines if the cached data needs to be updated based on the size and modification time of the hostsfile. +func (h *Hostsfile) readHosts() { + file, err := os.Open(h.path) + if err != nil { + // We already log a warning if the file doesn't exist or can't be opened on setup. No need to return the error here. + return + } + defer file.Close() + + stat, err := file.Stat() + if err != nil { + return + } + h.RLock() + size := h.size + h.RUnlock() + + if h.mtime.Equal(stat.ModTime()) && size == stat.Size() { + return + } + + newMap := h.parse(file) + log.Debugf("Parsed hosts file into %d entries", newMap.Len()) + + h.Lock() + + h.hmap = newMap + // Update the data cache. + h.mtime = stat.ModTime() + h.size = stat.Size() + + hostsEntries.WithLabelValues().Set(float64(h.inline.Len() + h.hmap.Len())) + hostsReloadTime.Set(float64(stat.ModTime().UnixNano()) / 1e9) + h.Unlock() +} + +func (h *Hostsfile) initInline(inline []string) { + if len(inline) == 0 { + return + } + + h.inline = h.parse(strings.NewReader(strings.Join(inline, "\n"))) +} + +// Parse reads the hostsfile and populates the byName and addr maps. +func (h *Hostsfile) parse(r io.Reader) *Map { + hmap := newMap() + + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := scanner.Bytes() + if i := bytes.Index(line, []byte{'#'}); i >= 0 { + // Discard comments. 
+ line = line[0:i] + } + f := bytes.Fields(line) + if len(f) < 2 { + continue + } + addr := parseIP(string(f[0])) + if addr == nil { + continue + } + + family := 0 + if addr.To4() != nil { + family = 1 + } else { + family = 2 + } + + for i := 1; i < len(f); i++ { + name := plugin.Name(string(f[i])).Normalize() + if plugin.Zones(h.Origins).Matches(name) == "" { + // name is not in Origins + continue + } + switch family { + case 1: + hmap.name4[name] = append(hmap.name4[name], addr) + case 2: + hmap.name6[name] = append(hmap.name6[name], addr) + default: + continue + } + if !h.options.autoReverse { + continue + } + hmap.addr[addr.String()] = append(hmap.addr[addr.String()], name) + } + } + + return hmap +} + +// lookupStaticHost looks up the IP addresses for the given host from the hosts file. +func (h *Hostsfile) lookupStaticHost(m map[string][]net.IP, host string) []net.IP { + h.RLock() + defer h.RUnlock() + + if len(m) == 0 { + return nil + } + + ips, ok := m[host] + if !ok { + return nil + } + ipsCp := make([]net.IP, len(ips)) + copy(ipsCp, ips) + return ipsCp +} + +// LookupStaticHostV4 looks up the IPv4 addresses for the given host from the hosts file. +func (h *Hostsfile) LookupStaticHostV4(host string) []net.IP { + host = strings.ToLower(host) + ip1 := h.lookupStaticHost(h.hmap.name4, host) + ip2 := h.lookupStaticHost(h.inline.name4, host) + return append(ip1, ip2...) +} + +// LookupStaticHostV6 looks up the IPv6 addresses for the given host from the hosts file. +func (h *Hostsfile) LookupStaticHostV6(host string) []net.IP { + host = strings.ToLower(host) + ip1 := h.lookupStaticHost(h.hmap.name6, host) + ip2 := h.lookupStaticHost(h.inline.name6, host) + return append(ip1, ip2...) +} + +// LookupStaticAddr looks up the hosts for the given address from the hosts file. +func (h *Hostsfile) LookupStaticAddr(addr string) []string { + addr = parseIP(addr).String() + if addr == "" { + return nil + } + + h.RLock() + defer h.RUnlock() + hosts1 := h.hmap.addr[addr] + hosts2 := h.inline.addr[addr] + + if len(hosts1) == 0 && len(hosts2) == 0 { + return nil + } + + hostsCp := make([]string, len(hosts1)+len(hosts2)) + copy(hostsCp, hosts1) + copy(hostsCp[len(hosts1):], hosts2) + return hostsCp +} diff --git a/plugin/hosts/hostsfile_test.go b/plugin/hosts/hostsfile_test.go new file mode 100644 index 0000000..05b064e --- /dev/null +++ b/plugin/hosts/hostsfile_test.go @@ -0,0 +1,241 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package hosts + +import ( + "net" + "reflect" + "strings" + "testing" + + "github.com/coredns/coredns/plugin" +) + +func testHostsfile(file string) *Hostsfile { + h := &Hostsfile{ + Origins: []string{"."}, + hmap: newMap(), + inline: newMap(), + options: newOptions(), + } + h.hmap = h.parse(strings.NewReader(file)) + return h +} + +type staticHostEntry struct { + in string + v4 []string + v6 []string +} + +var ( + hosts = `255.255.255.255 broadcasthost + 127.0.0.2 odin + 127.0.0.3 odin # inline comment + ::2 odin + 127.1.1.1 thor + # aliases + 127.1.1.2 ullr ullrhost + fe80::1%lo0 localhost + # Bogus entries that must be ignored. + 123.123.123 loki + 321.321.321.321` + singlelinehosts = `127.0.0.2 odin` + ipv4hosts = `# See https://tools.ietf.org/html/rfc1123. 
+ # + + # internet address and host name + 127.0.0.1 localhost # inline comment separated by tab + 127.0.0.2 localhost # inline comment separated by space + + # internet address, host name and aliases + 127.0.0.3 localhost localhost.localdomain` + ipv6hosts = `# See https://tools.ietf.org/html/rfc5952, https://tools.ietf.org/html/rfc4007. + + # internet address and host name + ::1 localhost # inline comment separated by tab + fe80:0000:0000:0000:0000:0000:0000:0001 localhost # inline comment separated by space + + # internet address with zone identifier and host name + fe80:0000:0000:0000:0000:0000:0000:0002%lo0 localhost + + # internet address, host name and aliases + fe80::3%lo0 localhost localhost.localdomain` + casehosts = `127.0.0.1 PreserveMe PreserveMe.local + ::1 PreserveMe PreserveMe.local` +) + +var lookupStaticHostTests = []struct { + file string + ents []staticHostEntry +}{ + { + hosts, + []staticHostEntry{ + {"odin.", []string{"127.0.0.2", "127.0.0.3"}, []string{"::2"}}, + {"thor.", []string{"127.1.1.1"}, []string{}}, + {"ullr.", []string{"127.1.1.2"}, []string{}}, + {"ullrhost.", []string{"127.1.1.2"}, []string{}}, + {"localhost.", []string{}, []string{"fe80::1"}}, + }, + }, + { + singlelinehosts, // see golang.org/issue/6646 + []staticHostEntry{ + {"odin.", []string{"127.0.0.2"}, []string{}}, + }, + }, + { + ipv4hosts, + []staticHostEntry{ + {"localhost.", []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"}, []string{}}, + {"localhost.localdomain.", []string{"127.0.0.3"}, []string{}}, + }, + }, + { + ipv6hosts, + []staticHostEntry{ + {"localhost.", []string{}, []string{"::1", "fe80::1", "fe80::2", "fe80::3"}}, + {"localhost.localdomain.", []string{}, []string{"fe80::3"}}, + }, + }, + { + casehosts, + []staticHostEntry{ + {"PreserveMe.", []string{"127.0.0.1"}, []string{"::1"}}, + {"PreserveMe.local.", []string{"127.0.0.1"}, []string{"::1"}}, + }, + }, +} + +func TestLookupStaticHost(t *testing.T) { + for _, tt := range lookupStaticHostTests { + h := testHostsfile(tt.file) + for _, ent := range tt.ents { + testStaticHost(t, ent, h) + } + } +} + +func testStaticHost(t *testing.T, ent staticHostEntry, h *Hostsfile) { + ins := []string{ent.in, plugin.Name(ent.in).Normalize(), strings.ToLower(ent.in), strings.ToUpper(ent.in)} + for k, in := range ins { + addrsV4 := h.LookupStaticHostV4(in) + if len(addrsV4) != len(ent.v4) { + t.Fatalf("%d, lookupStaticHostV4(%s) = %v; want %v", k, in, addrsV4, ent.v4) + } + for i, v4 := range addrsV4 { + if v4.String() != ent.v4[i] { + t.Fatalf("%d, lookupStaticHostV4(%s) = %v; want %v", k, in, addrsV4, ent.v4) + } + } + addrsV6 := h.LookupStaticHostV6(in) + if len(addrsV6) != len(ent.v6) { + t.Fatalf("%d, lookupStaticHostV6(%s) = %v; want %v", k, in, addrsV6, ent.v6) + } + for i, v6 := range addrsV6 { + if v6.String() != ent.v6[i] { + t.Fatalf("%d, lookupStaticHostV6(%s) = %v; want %v", k, in, addrsV6, ent.v6) + } + } + } +} + +type staticIPEntry struct { + in string + out []string +} + +var lookupStaticAddrTests = []struct { + file string + ents []staticIPEntry +}{ + { + hosts, + []staticIPEntry{ + {"255.255.255.255", []string{"broadcasthost."}}, + {"127.0.0.2", []string{"odin."}}, + {"127.0.0.3", []string{"odin."}}, + {"::2", []string{"odin."}}, + {"127.1.1.1", []string{"thor."}}, + {"127.1.1.2", []string{"ullr.", "ullrhost."}}, + {"fe80::1", []string{"localhost."}}, + }, + }, + { + singlelinehosts, // see golang.org/issue/6646 + []staticIPEntry{ + {"127.0.0.2", []string{"odin."}}, + }, + }, + { + ipv4hosts, // see golang.org/issue/8996 + 
[]staticIPEntry{ + {"127.0.0.1", []string{"localhost."}}, + {"127.0.0.2", []string{"localhost."}}, + {"127.0.0.3", []string{"localhost.", "localhost.localdomain."}}, + }, + }, + { + ipv6hosts, // see golang.org/issue/8996 + []staticIPEntry{ + {"::1", []string{"localhost."}}, + {"fe80::1", []string{"localhost."}}, + {"fe80::2", []string{"localhost."}}, + {"fe80::3", []string{"localhost.", "localhost.localdomain."}}, + }, + }, + { + casehosts, // see golang.org/issue/12806 + []staticIPEntry{ + {"127.0.0.1", []string{"PreserveMe.", "PreserveMe.local."}}, + {"::1", []string{"PreserveMe.", "PreserveMe.local."}}, + }, + }, +} + +func TestLookupStaticAddr(t *testing.T) { + for _, tt := range lookupStaticAddrTests { + h := testHostsfile(tt.file) + for _, ent := range tt.ents { + testStaticAddr(t, ent, h) + } + } +} + +func testStaticAddr(t *testing.T, ent staticIPEntry, h *Hostsfile) { + hosts := h.LookupStaticAddr(ent.in) + for i := range ent.out { + ent.out[i] = plugin.Name(ent.out[i]).Normalize() + } + if !reflect.DeepEqual(hosts, ent.out) { + t.Errorf("%s, lookupStaticAddr(%s) = %v; want %v", h.path, ent.in, hosts, h) + } +} + +func TestHostCacheModification(t *testing.T) { + // Ensure that programs can't modify the internals of the host cache. + // See https://github.com/golang/go/issues/14212. + + h := testHostsfile(ipv4hosts) + ent := staticHostEntry{"localhost.", []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"}, []string{}} + testStaticHost(t, ent, h) + // Modify the addresses return by lookupStaticHost. + addrs := h.LookupStaticHostV6(ent.in) + for i := range addrs { + addrs[i] = net.IPv4zero + } + testStaticHost(t, ent, h) + + h = testHostsfile(ipv6hosts) + entip := staticIPEntry{"::1", []string{"localhost."}} + testStaticAddr(t, entip, h) + // Modify the hosts return by lookupStaticAddr. + hosts := h.LookupStaticAddr(entip.in) + for i := range hosts { + hosts[i] += "junk" + } + testStaticAddr(t, entip, h) +} diff --git a/plugin/hosts/log_test.go b/plugin/hosts/log_test.go new file mode 100644 index 0000000..e784bd6 --- /dev/null +++ b/plugin/hosts/log_test.go @@ -0,0 +1,5 @@ +package hosts + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/hosts/metrics.go b/plugin/hosts/metrics.go new file mode 100644 index 0000000..f97497b --- /dev/null +++ b/plugin/hosts/metrics.go @@ -0,0 +1,25 @@ +package hosts + +import ( + "github.com/coredns/coredns/plugin" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +var ( + // hostsEntries is the combined number of entries in hosts and Corefile. + hostsEntries = promauto.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: plugin.Namespace, + Subsystem: "hosts", + Name: "entries", + Help: "The combined number of entries in hosts and Corefile.", + }, []string{}) + // hostsReloadTime is the timestamp of the last reload of hosts file. 
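+	// The gauge is set to the hosts file's modification time, expressed as Unix seconds.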
+ hostsReloadTime = promauto.NewGauge(prometheus.GaugeOpts{ + Namespace: plugin.Namespace, + Subsystem: "hosts", + Name: "reload_timestamp_seconds", + Help: "The timestamp of the last reload of hosts file.", + }) +) diff --git a/plugin/hosts/setup.go b/plugin/hosts/setup.go new file mode 100644 index 0000000..128a365 --- /dev/null +++ b/plugin/hosts/setup.go @@ -0,0 +1,158 @@ +package hosts + +import ( + "os" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + clog "github.com/coredns/coredns/plugin/pkg/log" +) + +var log = clog.NewWithPlugin("hosts") + +func init() { plugin.Register("hosts", setup) } + +func periodicHostsUpdate(h *Hosts) chan bool { + parseChan := make(chan bool) + + if h.options.reload == 0 { + return parseChan + } + + go func() { + ticker := time.NewTicker(h.options.reload) + defer ticker.Stop() + for { + select { + case <-parseChan: + return + case <-ticker.C: + h.readHosts() + } + } + }() + return parseChan +} + +func setup(c *caddy.Controller) error { + h, err := hostsParse(c) + if err != nil { + return plugin.Error("hosts", err) + } + + parseChan := periodicHostsUpdate(&h) + + c.OnStartup(func() error { + h.readHosts() + return nil + }) + + c.OnShutdown(func() error { + close(parseChan) + return nil + }) + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + h.Next = next + return h + }) + + return nil +} + +func hostsParse(c *caddy.Controller) (Hosts, error) { + config := dnsserver.GetConfig(c) + + h := Hosts{ + Hostsfile: &Hostsfile{ + path: "/etc/hosts", + hmap: newMap(), + inline: newMap(), + options: newOptions(), + }, + } + + inline := []string{} + i := 0 + for c.Next() { + if i > 0 { + return h, plugin.ErrOnce + } + i++ + + args := c.RemainingArgs() + + if len(args) >= 1 { + h.path = args[0] + args = args[1:] + + if !filepath.IsAbs(h.path) && config.Root != "" { + h.path = filepath.Join(config.Root, h.path) + } + s, err := os.Stat(h.path) + if err != nil { + if os.IsNotExist(err) { + log.Warningf("File does not exist: %s", h.path) + } else { + return h, c.Errf("unable to access hosts file '%s': %v", h.path, err) + } + } + if s != nil && s.IsDir() { + log.Warningf("Hosts file %q is a directory", h.path) + } + } + + h.Origins = plugin.OriginsFromArgsOrServerBlock(args, c.ServerBlockKeys) + + for c.NextBlock() { + switch c.Val() { + case "fallthrough": + h.Fall.SetZonesFromArgs(c.RemainingArgs()) + case "no_reverse": + h.options.autoReverse = false + case "ttl": + remaining := c.RemainingArgs() + if len(remaining) < 1 { + return h, c.Errf("ttl needs a time in second") + } + ttl, err := strconv.Atoi(remaining[0]) + if err != nil { + return h, c.Errf("ttl needs a number of second") + } + if ttl <= 0 || ttl > 65535 { + return h, c.Errf("ttl provided is invalid") + } + h.options.ttl = uint32(ttl) + case "reload": + remaining := c.RemainingArgs() + if len(remaining) != 1 { + return h, c.Errf("reload needs a duration (zero seconds to disable)") + } + reload, err := time.ParseDuration(remaining[0]) + if err != nil { + return h, c.Errf("invalid duration for reload '%s'", remaining[0]) + } + if reload < 0 { + return h, c.Errf("invalid negative duration for reload '%s'", remaining[0]) + } + h.options.reload = reload + default: + if len(h.Fall.Zones) == 0 { + line := strings.Join(append([]string{c.Val()}, c.RemainingArgs()...), " ") + inline = append(inline, line) + continue + } + return h, c.Errf("unknown property '%s'", c.Val()) + 
} + } + } + + h.initInline(inline) + + return h, nil +} diff --git a/plugin/hosts/setup_test.go b/plugin/hosts/setup_test.go new file mode 100644 index 0000000..38c7c31 --- /dev/null +++ b/plugin/hosts/setup_test.go @@ -0,0 +1,169 @@ +package hosts + +import ( + "testing" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/plugin/pkg/fall" +) + +func TestHostsParse(t *testing.T) { + tests := []struct { + inputFileRules string + shouldErr bool + expectedPath string + expectedOrigins []string + expectedFallthrough fall.F + }{ + { + `hosts +`, + false, "/etc/hosts", nil, fall.Zero, + }, + { + `hosts /tmp`, + false, "/tmp", nil, fall.Zero, + }, + { + `hosts /etc/hosts miek.nl.`, + false, "/etc/hosts", []string{"miek.nl."}, fall.Zero, + }, + { + `hosts /etc/hosts miek.nl. pun.gent.`, + false, "/etc/hosts", []string{"miek.nl.", "pun.gent."}, fall.Zero, + }, + { + `hosts { + fallthrough + }`, + false, "/etc/hosts", nil, fall.Root, + }, + { + `hosts /tmp { + fallthrough + }`, + false, "/tmp", nil, fall.Root, + }, + { + `hosts /etc/hosts miek.nl. { + fallthrough + }`, + false, "/etc/hosts", []string{"miek.nl."}, fall.Root, + }, + { + `hosts /etc/hosts miek.nl 10.0.0.9/8 { + fallthrough + }`, + false, "/etc/hosts", []string{"miek.nl.", "10.in-addr.arpa."}, fall.Root, + }, + { + `hosts /etc/hosts { + fallthrough + } + hosts /etc/hosts { + fallthrough + }`, + true, "/etc/hosts", nil, fall.Root, + }, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.inputFileRules) + h, err := hostsParse(c) + + if err == nil && test.shouldErr { + t.Fatalf("Test %d expected errors, but got no error", i) + } else if err != nil && !test.shouldErr { + t.Fatalf("Test %d expected no errors, but got '%v'", i, err) + } else if !test.shouldErr { + if h.path != test.expectedPath { + t.Fatalf("Test %d expected %v, got %v", i, test.expectedPath, h.path) + } + } else { + if !h.Fall.Equal(test.expectedFallthrough) { + t.Fatalf("Test %d expected fallthrough of %v, got %v", i, test.expectedFallthrough, h.Fall) + } + if len(h.Origins) != len(test.expectedOrigins) { + t.Fatalf("Test %d expected %v, got %v", i, test.expectedOrigins, h.Origins) + } + for j, name := range test.expectedOrigins { + if h.Origins[j] != name { + t.Fatalf("Test %d expected %v for %d th zone, got %v", i, name, j, h.Origins[j]) + } + } + } + } +} + +func TestHostsInlineParse(t *testing.T) { + tests := []struct { + inputFileRules string + shouldErr bool + expectedaddr map[string][]string + expectedFallthrough fall.F + }{ + { + `hosts highly_unlikely_to_exist_hosts_file example.org { + 10.0.0.1 example.org + fallthrough + }`, + false, + map[string][]string{ + `10.0.0.1`: { + `example.org.`, + }, + }, + fall.Root, + }, + { + `hosts highly_unlikely_to_exist_hosts_file example.org { + 10.0.0.1 example.org + }`, + false, + map[string][]string{ + `10.0.0.1`: { + `example.org.`, + }, + }, + fall.Zero, + }, + { + `hosts highly_unlikely_to_exist_hosts_file example.org { + fallthrough + 10.0.0.1 example.org + }`, + true, + map[string][]string{}, + fall.Root, + }, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.inputFileRules) + h, err := hostsParse(c) + if err == nil && test.shouldErr { + t.Fatalf("Test %d expected errors, but got no error", i) + } else if err != nil && !test.shouldErr { + t.Fatalf("Test %d expected no errors, but got '%v'", i, err) + } else if !test.shouldErr { + if !h.Fall.Equal(test.expectedFallthrough) { + t.Errorf("Test %d expected fallthrough of %v, got %v", i, 
test.expectedFallthrough, h.Fall) + } + for k, expectedVal := range test.expectedaddr { + val, ok := h.inline.addr[k] + if !ok { + t.Errorf("Test %d expected %v, got no entry", i, k) + continue + } + if len(expectedVal) != len(val) { + t.Errorf("Test %d expected %v records for %v, got %v", i, len(expectedVal), k, len(val)) + } + for j := range expectedVal { + if expectedVal[j] != val[j] { + t.Errorf("Test %d expected %v for %v, got %v", i, expectedVal[j], j, val[j]) + } + } + } + } + } +} diff --git a/plugin/import/README.md b/plugin/import/README.md new file mode 100644 index 0000000..aaaaa1b --- /dev/null +++ b/plugin/import/README.md @@ -0,0 +1,73 @@ +# import + +## Name + +*import* - includes files or references snippets from a Corefile. + +## Description + +The *import* plugin can be used to include files into the main configuration. Another use is to +reference predefined snippets. Both can help to avoid some duplication. + +This is a unique plugin in that *import* can appear outside of a server block. In other words, it +can appear at the top of a Corefile where an address would normally be. + +## Syntax + +~~~ +import PATTERN +~~~ + +* **PATTERN** is the file, glob pattern (`*`) or snippet to include. Its contents will replace + this line, as if that file's contents appeared here to begin with. + +## Files + +You can use *import* to include a file or files. This file's location is relative to the +Corefile's location. It is an error if a specific file cannot be found, but an empty glob pattern is +not an error. + +## Snippets + +You can define snippets to be reused later in your Corefile by defining a block with a single-token +label surrounded by parentheses: + +~~~ corefile +(mysnippet) { + ... +} +~~~ + +Then you can invoke the snippet with *import*: + +~~~ +import mysnippet +~~~ + +## Examples + +Import a shared configuration: + +~~~ +. { + import config/common.conf +} +~~~ + +Where `config/common.conf` contains: + +~~~ +prometheus +errors +log +~~~ + +This imports files found in the zones directory: + +~~~ +import ../zones/* +~~~ + +## See Also + +See corefile(5). diff --git a/plugin/k8s_external/README.md b/plugin/k8s_external/README.md new file mode 100644 index 0000000..893a131 --- /dev/null +++ b/plugin/k8s_external/README.md @@ -0,0 +1,132 @@ +# k8s_external + +## Name + +*k8s_external* - resolves load balancer, external IPs from outside Kubernetes clusters and if enabled headless services. + +## Description + +This plugin allows an additional zone to resolve the external IP address(es) of a Kubernetes +service and headless services. This plugin is only useful if the *kubernetes* plugin is also loaded. + +The plugin uses an external zone to resolve in-cluster IP addresses. It only handles queries for A, +AAAA, SRV, and PTR records; To make it a proper DNS zone, it handles SOA and NS queries for the apex of the zone. + +By default the apex of the zone will look like the following (assuming the zone used is `example.org`): + +~~~ dns +example.org. 5 IN SOA ns1.dns.example.org. hostmaster.example.org. ( + 12345 ; serial + 14400 ; refresh (4 hours) + 3600 ; retry (1 hour) + 604800 ; expire (1 week) + 5 ; minimum (4 hours) + ) +example.org 5 IN NS ns1.dns.example.org. + +ns1.dns.example.org. 5 IN A .... +ns1.dns.example.org. 5 IN AAAA .... +~~~ + +Note that we use the `dns` subdomain for the records DNS needs (see the `apex` directive). Also +note the SOA's serial number is static. The IP addresses of the nameserver records are those of the +CoreDNS service. 
+ +The *k8s_external* plugin handles the subdomain `dns` and the apex of the zone itself; all other +queries are resolved to addresses in the cluster. + +## Syntax + +~~~ +k8s_external [ZONE...] +~~~ + +* **ZONES** zones *k8s_external* should be authoritative for. + +If you want to change the apex domain or use a different TTL for the returned records you can use +this extended syntax. + +~~~ +k8s_external [ZONE...] { + apex APEX + ttl TTL +} +~~~ + +* **APEX** is the name (DNS label) to use for the apex records; it defaults to `dns`. +* `ttl` allows you to set a custom **TTL** for responses. The default is 5 (seconds). + +If you want to enable headless service resolution, you can do so by adding `headless` option. + +~~~ +k8s_external [ZONE...] { + headless +} +~~~ + +* if there is a headless service with external IPs set, external IPs will be resolved + +If the queried domain does not exist, you can fall through to next plugin by adding the `fallthrough` option. + +~~~ +k8s_external [ZONE...] { + fallthrough [ZONE...] +} +~~~ + +## Examples + +Enable names under `example.org` to be resolved to in-cluster DNS addresses. + +~~~ +. { + kubernetes cluster.local + k8s_external example.org +} +~~~ + +With the Corefile above, the following Service will get an `A` record for `test.default.example.org` with the IP address `192.168.200.123`. + +~~~ +apiVersion: v1 +kind: Service +metadata: + name: test + namespace: default +spec: + clusterIP: None + externalIPs: + - 192.168.200.123 + type: ClusterIP +~~~ + +The *k8s_external* plugin can be used in conjunction with the *transfer* plugin to enable +zone transfers. Notifies are not supported. + + ~~~ + . { + transfer example.org { + to * + } + kubernetes cluster.local + k8s_external example.org + } + ~~~ + +With the `fallthrough` option, if the queried domain does not exist, it will be passed to the next plugin that matches the zone. + +~~~ +. { + kubernetes cluster.local + k8s_external example.org { + fallthrough + } + forward . 8.8.8.8 +} +~~~ + +# See Also + +For some background see [resolve external IP address](https://github.com/kubernetes/dns/issues/242). +And [A records for services with Load Balancer IP](https://github.com/coredns/coredns/issues/1851). + diff --git a/plugin/k8s_external/apex.go b/plugin/k8s_external/apex.go new file mode 100644 index 0000000..e575e5e --- /dev/null +++ b/plugin/k8s_external/apex.go @@ -0,0 +1,112 @@ +package external + +import ( + "github.com/coredns/coredns/plugin/pkg/dnsutil" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// serveApex serves request that hit the zone' apex. A reply is written back to the client. +func (e *External) serveApex(state request.Request) (int, error) { + m := new(dns.Msg) + m.SetReply(state.Req) + m.Authoritative = true + switch state.QType() { + case dns.TypeSOA: + m.Answer = []dns.RR{e.soa(state)} + case dns.TypeNS: + m.Answer = []dns.RR{e.ns(state)} + + addr := e.externalAddrFunc(state, e.headless) + for _, rr := range addr { + rr.Header().Ttl = e.ttl + rr.Header().Name = dnsutil.Join("ns1", e.apex, state.QName()) + m.Extra = append(m.Extra, rr) + } + default: + m.Ns = []dns.RR{e.soa(state)} + } + + state.W.WriteMsg(m) + return 0, nil +} + +// serveSubApex serves requests that hit the zones fake 'dns' subdomain where our nameservers live. 
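+// Only A and AAAA queries for the ns1 name under the apex label get address answers; the apex label itself yields NODATA and any other name below it yields NXDOMAIN.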
+func (e *External) serveSubApex(state request.Request) (int, error) { + base, _ := dnsutil.TrimZone(state.Name(), state.Zone) + + m := new(dns.Msg) + m.SetReply(state.Req) + m.Authoritative = true + + // base is either dns. of ns1.dns (or another name), if it's longer return nxdomain + switch labels := dns.CountLabel(base); labels { + default: + m.SetRcode(m, dns.RcodeNameError) + m.Ns = []dns.RR{e.soa(state)} + state.W.WriteMsg(m) + return 0, nil + case 2: + nl, _ := dns.NextLabel(base, 0) + ns := base[:nl] + if ns != "ns1." { + // nxdomain + m.SetRcode(m, dns.RcodeNameError) + m.Ns = []dns.RR{e.soa(state)} + state.W.WriteMsg(m) + return 0, nil + } + + addr := e.externalAddrFunc(state, e.headless) + for _, rr := range addr { + rr.Header().Ttl = e.ttl + rr.Header().Name = state.QName() + switch state.QType() { + case dns.TypeA: + if rr.Header().Rrtype == dns.TypeA { + m.Answer = append(m.Answer, rr) + } + case dns.TypeAAAA: + if rr.Header().Rrtype == dns.TypeAAAA { + m.Answer = append(m.Answer, rr) + } + } + } + + if len(m.Answer) == 0 { + m.Ns = []dns.RR{e.soa(state)} + } + + state.W.WriteMsg(m) + return 0, nil + + case 1: + // nodata for the dns empty non-terminal + m.Ns = []dns.RR{e.soa(state)} + state.W.WriteMsg(m) + return 0, nil + } +} + +func (e *External) soa(state request.Request) *dns.SOA { + header := dns.RR_Header{Name: state.Zone, Rrtype: dns.TypeSOA, Ttl: e.ttl, Class: dns.ClassINET} + + soa := &dns.SOA{Hdr: header, + Mbox: dnsutil.Join(e.hostmaster, e.apex, state.Zone), + Ns: dnsutil.Join("ns1", e.apex, state.Zone), + Serial: e.externalSerialFunc(state.Zone), + Refresh: 7200, + Retry: 1800, + Expire: 86400, + Minttl: e.ttl, + } + return soa +} + +func (e *External) ns(state request.Request) *dns.NS { + header := dns.RR_Header{Name: state.Zone, Rrtype: dns.TypeNS, Ttl: e.ttl, Class: dns.ClassINET} + ns := &dns.NS{Hdr: header, Ns: dnsutil.Join("ns1", e.apex, state.Zone)} + + return ns +} diff --git a/plugin/k8s_external/apex_test.go b/plugin/k8s_external/apex_test.go new file mode 100644 index 0000000..ab08187 --- /dev/null +++ b/plugin/k8s_external/apex_test.go @@ -0,0 +1,122 @@ +package external + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin/kubernetes" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestApex(t *testing.T) { + k := kubernetes.New([]string{"cluster.local."}) + k.Namespaces = map[string]struct{}{"testns": {}} + k.APIConn = &external{} + + e := New() + e.headless = true + e.Zones = []string{"example.com."} + e.Next = test.NextHandler(dns.RcodeSuccess, nil) + e.externalFunc = k.External + e.externalAddrFunc = externalAddress // internal test function + e.externalSerialFunc = externalSerial // internal test function + + ctx := context.TODO() + for i, tc := range testsApex { + r := tc.Msg() + w := dnstest.NewRecorder(&test.ResponseWriter{}) + + _, err := e.ServeDNS(ctx, w, r) + if err != tc.Error { + t.Errorf("Test %d expected no error, got %v", i, err) + return + } + if tc.Error != nil { + continue + } + + resp := w.Msg + if resp == nil { + t.Fatalf("Test %d, got nil message and no error for %q", i, r.Question[0].Name) + } + if !resp.Authoritative { + t.Error("Expected authoritative answer") + } + if err := test.SortAndCheck(resp, tc); err != nil { + t.Error(err) + } + for i, rr := range tc.Ns { + expectsoa := rr.(*dns.SOA) + gotsoa, ok := resp.Ns[i].(*dns.SOA) + if !ok { + t.Fatalf("Unexpected record type in Authority section") + } + if 
expectsoa.Serial != gotsoa.Serial { + t.Fatalf("Expected soa serial %d, got %d", expectsoa.Serial, gotsoa.Serial) + } + } + } +} + +var testsApex = []test.Case{ + { + Qname: "example.com.", Qtype: dns.TypeSOA, Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.SOA("example.com. 5 IN SOA ns1.dns.example.com. hostmaster.dns.example.com. 1499347823 7200 1800 86400 5"), + }, + }, + { + Qname: "example.com.", Qtype: dns.TypeNS, Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.NS("example.com. 5 IN NS ns1.dns.example.com."), + }, + Extra: []dns.RR{ + test.A("ns1.dns.example.com. 5 IN A 127.0.0.1"), + }, + }, + { + Qname: "example.com.", Qtype: dns.TypeSRV, Rcode: dns.RcodeSuccess, + Ns: []dns.RR{ + test.SOA("example.com. 5 IN SOA ns1.dns.example.com. hostmaster.dns.example.com. 1499347823 7200 1800 86400 5"), + }, + }, + { + Qname: "dns.example.com.", Qtype: dns.TypeSRV, Rcode: dns.RcodeSuccess, + Ns: []dns.RR{ + test.SOA("example.com. 5 IN SOA ns1.dns.example.com. hostmaster.dns.example.com. 1499347823 7200 1800 86400 5"), + }, + }, + { + Qname: "dns.example.com.", Qtype: dns.TypeNS, Rcode: dns.RcodeSuccess, + Ns: []dns.RR{ + test.SOA("example.com. 5 IN SOA ns1.dns.example.com. hostmaster.dns.example.com. 1499347823 7200 1800 86400 5"), + }, + }, + { + Qname: "ns1.dns.example.com.", Qtype: dns.TypeSRV, Rcode: dns.RcodeSuccess, + Ns: []dns.RR{ + test.SOA("example.com. 5 IN SOA ns1.dns.example.com. hostmaster.dns.example.com. 1499347823 7200 1800 86400 5"), + }, + }, + { + Qname: "ns1.dns.example.com.", Qtype: dns.TypeNS, Rcode: dns.RcodeSuccess, + Ns: []dns.RR{ + test.SOA("example.com. 5 IN SOA ns1.dns.example.com. hostmaster.dns.example.com. 1499347823 7200 1800 86400 5"), + }, + }, + { + Qname: "ns1.dns.example.com.", Qtype: dns.TypeAAAA, Rcode: dns.RcodeSuccess, + Ns: []dns.RR{ + test.SOA("example.com. 5 IN SOA ns1.dns.example.com. hostmaster.dns.example.com. 1499347823 7200 1800 86400 5"), + }, + }, + { + Qname: "ns1.dns.example.com.", Qtype: dns.TypeA, Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.A("ns1.dns.example.com. 5 IN A 127.0.0.1"), + }, + }, +} diff --git a/plugin/k8s_external/external.go b/plugin/k8s_external/external.go new file mode 100644 index 0000000..442119b --- /dev/null +++ b/plugin/k8s_external/external.go @@ -0,0 +1,126 @@ +/* +Package external implements external names for kubernetes clusters. + +This plugin only handles three qtypes (except the apex queries, because those are handled +differently). We support A, AAAA and SRV request, for all other types we return NODATA or +NXDOMAIN depending on the state of the cluster. + +A plugin willing to provide these services must implement the Externaler interface, although it +likely only makes sense for the *kubernetes* plugin. +*/ +package external + +import ( + "context" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/etcd/msg" + "github.com/coredns/coredns/plugin/pkg/fall" + "github.com/coredns/coredns/plugin/pkg/upstream" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// Externaler defines the interface that a plugin should implement in order to be used by External. +type Externaler interface { + // External returns a slice of msg.Services that are looked up in the backend and match + // the request. + External(request.Request, bool) ([]msg.Service, int) + // ExternalAddress should return a string slice of addresses for the nameserving endpoint. 
+ ExternalAddress(state request.Request, headless bool) []dns.RR + // ExternalServices returns all services in the given zone as a slice of msg.Service and if enabled, headless services as a map of services. + ExternalServices(zone string, headless bool) ([]msg.Service, map[string][]msg.Service) + // ExternalSerial gets the current serial. + ExternalSerial(string) uint32 +} + +// External serves records for External IPs and Loadbalance IPs of Services in Kubernetes clusters. +type External struct { + Next plugin.Handler + Zones []string + Fall fall.F + + hostmaster string + apex string + ttl uint32 + headless bool + + upstream *upstream.Upstream + + externalFunc func(request.Request, bool) ([]msg.Service, int) + externalAddrFunc func(request.Request, bool) []dns.RR + externalSerialFunc func(string) uint32 + externalServicesFunc func(string, bool) ([]msg.Service, map[string][]msg.Service) +} + +// New returns a new and initialized *External. +func New() *External { + e := &External{hostmaster: "hostmaster", ttl: 5, apex: "dns"} + return e +} + +// ServeDNS implements the plugin.Handle interface. +func (e *External) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + state := request.Request{W: w, Req: r} + + zone := plugin.Zones(e.Zones).Matches(state.Name()) + if zone == "" { + return plugin.NextOrFailure(e.Name(), e.Next, ctx, w, r) + } + + state.Zone = zone + for _, z := range e.Zones { + // TODO(miek): save this in the External struct. + if state.Name() == z { // apex query + ret, err := e.serveApex(state) + return ret, err + } + if dns.IsSubDomain(e.apex+"."+z, state.Name()) { + // dns subdomain test for ns. and dns. queries + ret, err := e.serveSubApex(state) + return ret, err + } + } + + svc, rcode := e.externalFunc(state, e.headless) + + m := new(dns.Msg) + m.SetReply(state.Req) + m.Authoritative = true + + if len(svc) == 0 { + if e.Fall.Through(state.Name()) && rcode == dns.RcodeNameError { + return plugin.NextOrFailure(e.Name(), e.Next, ctx, w, r) + } + + m.Rcode = rcode + m.Ns = []dns.RR{e.soa(state)} + w.WriteMsg(m) + return 0, nil + } + + switch state.QType() { + case dns.TypeA: + m.Answer, m.Truncated = e.a(ctx, svc, state) + case dns.TypeAAAA: + m.Answer, m.Truncated = e.aaaa(ctx, svc, state) + case dns.TypeSRV: + m.Answer, m.Extra = e.srv(ctx, svc, state) + case dns.TypePTR: + m.Answer = e.ptr(svc, state) + default: + m.Ns = []dns.RR{e.soa(state)} + } + + // If we did have records, but queried for the wrong qtype return a nodata response. + if len(m.Answer) == 0 { + m.Ns = []dns.RR{e.soa(state)} + } + + w.WriteMsg(m) + return 0, nil +} + +// Name implements the Handler interface. 
+func (e *External) Name() string { return "k8s_external" } diff --git a/plugin/k8s_external/external_test.go b/plugin/k8s_external/external_test.go new file mode 100644 index 0000000..1e3630b --- /dev/null +++ b/plugin/k8s_external/external_test.go @@ -0,0 +1,426 @@ +package external + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin/kubernetes" + "github.com/coredns/coredns/plugin/kubernetes/object" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" + api "k8s.io/api/core/v1" +) + +func TestExternal(t *testing.T) { + k := kubernetes.New([]string{"cluster.local."}) + k.Namespaces = map[string]struct{}{"testns": {}} + k.APIConn = &external{} + + e := New() + e.Zones = []string{"example.com.", "in-addr.arpa."} + e.headless = true + e.Next = test.NextHandler(dns.RcodeSuccess, nil) + e.externalFunc = k.External + e.externalAddrFunc = externalAddress // internal test function + e.externalSerialFunc = externalSerial // internal test function + + ctx := context.TODO() + for i, tc := range tests { + r := tc.Msg() + w := dnstest.NewRecorder(&test.ResponseWriter{}) + + _, err := e.ServeDNS(ctx, w, r) + if err != tc.Error { + t.Errorf("Test %d expected no error, got %v", i, err) + return + } + if tc.Error != nil { + continue + } + + resp := w.Msg + + if resp == nil { + t.Fatalf("Test %d, got nil message and no error for %q", i, r.Question[0].Name) + } + if !resp.Authoritative { + t.Error("Expected authoritative answer") + } + if err = test.SortAndCheck(resp, tc); err != nil { + t.Errorf("Test %d: %v", i, err) + } + } +} + +var tests = []test.Case{ + // PTR reverse lookup + { + Qname: "4.3.2.1.in-addr.arpa.", Qtype: dns.TypePTR, Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.PTR("4.3.2.1.in-addr.arpa. 5 IN PTR svc1.testns.example.com."), + }, + }, + // Bad PTR reverse lookup using existing service name + { + Qname: "svc1.testns.example.com.", Qtype: dns.TypePTR, Rcode: dns.RcodeSuccess, + Ns: []dns.RR{ + test.SOA("example.com. 5 IN SOA ns1.dns.example.com. hostmaster.example.com. 1499347823 7200 1800 86400 5"), + }, + }, + // Bad PTR reverse lookup using non-existing service name + { + Qname: "not-existing.testns.example.com.", Qtype: dns.TypePTR, Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("example.com. 5 IN SOA ns1.dns.example.com. hostmaster.example.com. 1499347823 7200 1800 86400 5"), + }, + }, + // A Service + { + Qname: "svc1.testns.example.com.", Qtype: dns.TypeA, Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.A("svc1.testns.example.com. 5 IN A 1.2.3.4"), + }, + }, + { + Qname: "svc1.testns.example.com.", Qtype: dns.TypeSRV, Rcode: dns.RcodeSuccess, + Answer: []dns.RR{test.SRV("svc1.testns.example.com. 5 IN SRV 0 100 80 svc1.testns.example.com.")}, + Extra: []dns.RR{test.A("svc1.testns.example.com. 5 IN A 1.2.3.4")}, + }, + // SRV Service Not udp/tcp + { + Qname: "*._not-udp-or-tcp.svc1.testns.example.com.", Qtype: dns.TypeSRV, Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("example.com. 5 IN SOA ns1.dns.example.com. hostmaster.example.com. 1499347823 7200 1800 86400 5"), + }, + }, + // SRV Service + { + Qname: "_http._tcp.svc1.testns.example.com.", Qtype: dns.TypeSRV, Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.SRV("_http._tcp.svc1.testns.example.com. 5 IN SRV 0 100 80 svc1.testns.example.com."), + }, + Extra: []dns.RR{ + test.A("svc1.testns.example.com. 
5 IN A 1.2.3.4"), + }, + }, + // AAAA Service (with an existing A record, but no AAAA record) + { + Qname: "svc1.testns.example.com.", Qtype: dns.TypeAAAA, Rcode: dns.RcodeSuccess, + Ns: []dns.RR{ + test.SOA("example.com. 5 IN SOA ns1.dns.example.com. hostmaster.example.com. 1499347823 7200 1800 86400 5"), + }, + }, + // AAAA Service (non-existing service) + { + Qname: "svc0.testns.example.com.", Qtype: dns.TypeAAAA, Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("example.com. 5 IN SOA ns1.dns.example.com. hostmaster.example.com. 1499347823 7200 1800 86400 5"), + }, + }, + // A Service (non-existing service) + { + Qname: "svc0.testns.example.com.", Qtype: dns.TypeA, Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("example.com. 5 IN SOA ns1.dns.example.com. hostmaster.example.com. 1499347823 7200 1800 86400 5"), + }, + }, + // A Service (non-existing namespace) + { + Qname: "svc0.svc-nons.example.com.", Qtype: dns.TypeA, Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("example.com. 5 IN SOA ns1.dns.example.com. hostmaster.example.com. 1499347823 7200 1800 86400 5"), + }, + }, + // AAAA Service + { + Qname: "svc6.testns.example.com.", Qtype: dns.TypeAAAA, Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.AAAA("svc6.testns.example.com. 5 IN AAAA 1:2::5"), + }, + }, + // SRV + { + Qname: "_http._tcp.svc6.testns.example.com.", Qtype: dns.TypeSRV, Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.SRV("_http._tcp.svc6.testns.example.com. 5 IN SRV 0 100 80 svc6.testns.example.com."), + }, + Extra: []dns.RR{ + test.AAAA("svc6.testns.example.com. 5 IN AAAA 1:2::5"), + }, + }, + // SRV + { + Qname: "svc6.testns.example.com.", Qtype: dns.TypeSRV, Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.SRV("svc6.testns.example.com. 5 IN SRV 0 100 80 svc6.testns.example.com."), + }, + Extra: []dns.RR{ + test.AAAA("svc6.testns.example.com. 5 IN AAAA 1:2::5"), + }, + }, + { + Qname: "testns.example.com.", Qtype: dns.TypeA, Rcode: dns.RcodeSuccess, + Ns: []dns.RR{ + test.SOA("example.com. 5 IN SOA ns1.dns.example.com. hostmaster.example.com. 1499347823 7200 1800 86400 5"), + }, + }, + { + Qname: "testns.example.com.", Qtype: dns.TypeSOA, Rcode: dns.RcodeSuccess, + Ns: []dns.RR{ + test.SOA("example.com. 5 IN SOA ns1.dns.example.com. hostmaster.example.com. 1499347823 7200 1800 86400 5"), + }, + }, + // svc11 + { + Qname: "svc11.testns.example.com.", Qtype: dns.TypeA, Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.A("svc11.testns.example.com. 5 IN A 2.3.4.5"), + }, + }, + { + Qname: "_http._tcp.svc11.testns.example.com.", Qtype: dns.TypeSRV, Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.SRV("_http._tcp.svc11.testns.example.com. 5 IN SRV 0 100 80 svc11.testns.example.com."), + }, + Extra: []dns.RR{ + test.A("svc11.testns.example.com. 5 IN A 2.3.4.5"), + }, + }, + { + Qname: "svc11.testns.example.com.", Qtype: dns.TypeSRV, Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.SRV("svc11.testns.example.com. 5 IN SRV 0 100 80 svc11.testns.example.com."), + }, + Extra: []dns.RR{ + test.A("svc11.testns.example.com. 5 IN A 2.3.4.5"), + }, + }, + // svc12 + { + Qname: "svc12.testns.example.com.", Qtype: dns.TypeA, Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.CNAME("svc12.testns.example.com. 5 IN CNAME dummy.hostname"), + }, + }, + { + Qname: "_http._tcp.svc12.testns.example.com.", Qtype: dns.TypeSRV, Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.SRV("_http._tcp.svc12.testns.example.com. 
5 IN SRV 0 100 80 dummy.hostname."), + }, + }, + { + Qname: "svc12.testns.example.com.", Qtype: dns.TypeSRV, Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.SRV("svc12.testns.example.com. 5 IN SRV 0 100 80 dummy.hostname."), + }, + }, + // headless service + { + Qname: "svc-headless.testns.example.com.", Qtype: dns.TypeA, Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.A("svc-headless.testns.example.com. 5 IN A 1.2.3.4"), + test.A("svc-headless.testns.example.com. 5 IN A 1.2.3.5"), + }, + }, + { + Qname: "svc-headless.testns.example.com.", Qtype: dns.TypeSRV, Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.SRV("svc-headless.testns.example.com. 5 IN SRV 0 50 80 endpoint-svc-0.svc-headless.testns.example.com."), + test.SRV("svc-headless.testns.example.com. 5 IN SRV 0 50 80 endpoint-svc-1.svc-headless.testns.example.com."), + }, + Extra: []dns.RR{ + test.A("endpoint-svc-0.svc-headless.testns.example.com. 5 IN A 1.2.3.4"), + test.A("endpoint-svc-1.svc-headless.testns.example.com. 5 IN A 1.2.3.5"), + }, + }, + { + Qname: "_http._tcp.svc-headless.testns.example.com.", Qtype: dns.TypeSRV, Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.SRV("_http._tcp.svc-headless.testns.example.com. 5 IN SRV 0 50 80 endpoint-svc-0.svc-headless.testns.example.com."), + test.SRV("_http._tcp.svc-headless.testns.example.com. 5 IN SRV 0 50 80 endpoint-svc-1.svc-headless.testns.example.com."), + }, + Extra: []dns.RR{ + test.A("endpoint-svc-0.svc-headless.testns.example.com. 5 IN A 1.2.3.4"), + test.A("endpoint-svc-1.svc-headless.testns.example.com. 5 IN A 1.2.3.5"), + }, + }, + { + Qname: "endpoint-svc-0.svc-headless.testns.example.com.", Qtype: dns.TypeSRV, Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.SRV("endpoint-svc-0.svc-headless.testns.example.com. 5 IN SRV 0 100 80 endpoint-svc-0.svc-headless.testns.example.com."), + }, + Extra: []dns.RR{ + test.A("endpoint-svc-0.svc-headless.testns.example.com. 5 IN A 1.2.3.4"), + }, + }, + { + Qname: "endpoint-svc-1.svc-headless.testns.example.com.", Qtype: dns.TypeSRV, Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.SRV("endpoint-svc-1.svc-headless.testns.example.com. 5 IN SRV 0 100 80 endpoint-svc-1.svc-headless.testns.example.com."), + }, + Extra: []dns.RR{ + test.A("endpoint-svc-1.svc-headless.testns.example.com. 5 IN A 1.2.3.5"), + }, + }, + { + Qname: "endpoint-svc-0.svc-headless.testns.example.com.", Qtype: dns.TypeA, Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.A("endpoint-svc-0.svc-headless.testns.example.com. 5 IN A 1.2.3.4"), + }, + }, + { + Qname: "endpoint-svc-1.svc-headless.testns.example.com.", Qtype: dns.TypeA, Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.A("endpoint-svc-1.svc-headless.testns.example.com. 5 IN A 1.2.3.5"), + }, + }, +} + +type external struct{} + +func (external) HasSynced() bool { return true } +func (external) Run() {} +func (external) Stop() error { return nil } +func (external) EpIndexReverse(string) []*object.Endpoints { return nil } +func (external) SvcIndexReverse(string) []*object.Service { return nil } +func (external) Modified(bool) int64 { return 0 } +func (external) EpIndex(s string) []*object.Endpoints { + return epIndexExternal[s] +} +func (external) EndpointsList() []*object.Endpoints { + var eps []*object.Endpoints + for _, ep := range epIndexExternal { + eps = append(eps, ep...) 
+ } + return eps +} +func (external) GetNodeByName(ctx context.Context, name string) (*api.Node, error) { return nil, nil } +func (external) SvcIndex(s string) []*object.Service { return svcIndexExternal[s] } +func (external) PodIndex(string) []*object.Pod { return nil } + +func (external) SvcExtIndexReverse(ip string) (result []*object.Service) { + for _, svcs := range svcIndexExternal { + for _, svc := range svcs { + for _, exIp := range svc.ExternalIPs { + if exIp != ip { + continue + } + result = append(result, svc) + } + } + } + return result +} + +func (external) GetNamespaceByName(name string) (*object.Namespace, error) { + return &object.Namespace{ + Name: name, + }, nil +} + +var epIndexExternal = map[string][]*object.Endpoints{ + "svc-headless.testns": { + { + Name: "svc-headless", + Namespace: "testns", + Index: "svc-headless.testns", + Subsets: []object.EndpointSubset{ + { + Ports: []object.EndpointPort{ + { + Port: 80, + Name: "http", + Protocol: "TCP", + }, + }, + Addresses: []object.EndpointAddress{ + { + IP: "1.2.3.4", + Hostname: "endpoint-svc-0", + NodeName: "test-node", + TargetRefName: "endpoint-svc-0", + }, + { + IP: "1.2.3.5", + Hostname: "endpoint-svc-1", + NodeName: "test-node", + TargetRefName: "endpoint-svc-1", + }, + }, + }, + }, + }, + }, +} + +var svcIndexExternal = map[string][]*object.Service{ + "svc1.testns": { + { + Name: "svc1", + Namespace: "testns", + Type: api.ServiceTypeClusterIP, + ClusterIPs: []string{"10.0.0.1"}, + ExternalIPs: []string{"1.2.3.4"}, + Ports: []api.ServicePort{{Name: "http", Protocol: "tcp", Port: 80}}, + }, + }, + "svc6.testns": { + { + Name: "svc6", + Namespace: "testns", + Type: api.ServiceTypeClusterIP, + ClusterIPs: []string{"10.0.0.3"}, + ExternalIPs: []string{"1:2::5"}, + Ports: []api.ServicePort{{Name: "http", Protocol: "tcp", Port: 80}}, + }, + }, + "svc11.testns": { + { + Name: "svc11", + Namespace: "testns", + Type: api.ServiceTypeLoadBalancer, + ExternalIPs: []string{"2.3.4.5"}, + ClusterIPs: []string{"10.0.0.3"}, + Ports: []api.ServicePort{{Name: "http", Protocol: "tcp", Port: 80}}, + }, + }, + "svc12.testns": { + { + Name: "svc12", + Namespace: "testns", + Type: api.ServiceTypeLoadBalancer, + ClusterIPs: []string{"10.0.0.3"}, + ExternalIPs: []string{"dummy.hostname"}, + Ports: []api.ServicePort{{Name: "http", Protocol: "tcp", Port: 80}}, + }, + }, + "svc-headless.testns": { + { + Name: "svc-headless", + Namespace: "testns", + Type: api.ServiceTypeClusterIP, + ClusterIPs: []string{"None"}, + Ports: []api.ServicePort{{Name: "http", Protocol: "tcp", Port: 80}}, + }, + }, +} + +func (external) ServiceList() []*object.Service { + var svcs []*object.Service + for _, svc := range svcIndexExternal { + svcs = append(svcs, svc...) + } + return svcs +} + +func externalAddress(state request.Request, headless bool) []dns.RR { + a := test.A("example.org. 
IN A 127.0.0.1") + return []dns.RR{a} +} + +func externalSerial(string) uint32 { + return 1499347823 +} diff --git a/plugin/k8s_external/msg_to_dns.go b/plugin/k8s_external/msg_to_dns.go new file mode 100644 index 0000000..6975718 --- /dev/null +++ b/plugin/k8s_external/msg_to_dns.go @@ -0,0 +1,190 @@ +package external + +import ( + "context" + "math" + + "github.com/coredns/coredns/plugin/etcd/msg" + "github.com/coredns/coredns/plugin/pkg/dnsutil" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +func (e *External) a(ctx context.Context, services []msg.Service, state request.Request) (records []dns.RR, truncated bool) { + dup := make(map[string]struct{}) + + for _, s := range services { + what, ip := s.HostType() + + switch what { + case dns.TypeCNAME: + rr := s.NewCNAME(state.QName(), s.Host) + records = append(records, rr) + if resp, err := e.upstream.Lookup(ctx, state, dns.Fqdn(s.Host), dns.TypeA); err == nil { + records = append(records, resp.Answer...) + if resp.Truncated { + truncated = true + } + } + + case dns.TypeA: + if _, ok := dup[s.Host]; !ok { + dup[s.Host] = struct{}{} + rr := s.NewA(state.QName(), ip) + rr.Hdr.Ttl = e.ttl + records = append(records, rr) + } + + case dns.TypeAAAA: + // nada + } + } + return records, truncated +} + +func (e *External) aaaa(ctx context.Context, services []msg.Service, state request.Request) (records []dns.RR, truncated bool) { + dup := make(map[string]struct{}) + + for _, s := range services { + what, ip := s.HostType() + + switch what { + case dns.TypeCNAME: + rr := s.NewCNAME(state.QName(), s.Host) + records = append(records, rr) + if resp, err := e.upstream.Lookup(ctx, state, dns.Fqdn(s.Host), dns.TypeAAAA); err == nil { + records = append(records, resp.Answer...) + if resp.Truncated { + truncated = true + } + } + + case dns.TypeA: + // nada + + case dns.TypeAAAA: + if _, ok := dup[s.Host]; !ok { + dup[s.Host] = struct{}{} + rr := s.NewAAAA(state.QName(), ip) + rr.Hdr.Ttl = e.ttl + records = append(records, rr) + } + } + } + return records, truncated +} + +func (e *External) ptr(services []msg.Service, state request.Request) (records []dns.RR) { + dup := make(map[string]struct{}) + for _, s := range services { + if _, ok := dup[s.Host]; !ok { + dup[s.Host] = struct{}{} + rr := s.NewPTR(state.QName(), dnsutil.Join(s.Host, e.Zones[0])) + rr.Hdr.Ttl = e.ttl + records = append(records, rr) + } + } + return records +} + +func (e *External) srv(ctx context.Context, services []msg.Service, state request.Request) (records, extra []dns.RR) { + dup := make(map[item]struct{}) + + // Looping twice to get the right weight vs priority. This might break because we may drop duplicate SRV records latter on. + w := make(map[int]int) + for _, s := range services { + weight := 100 + if s.Weight != 0 { + weight = s.Weight + } + if _, ok := w[s.Priority]; !ok { + w[s.Priority] = weight + continue + } + w[s.Priority] += weight + } + for _, s := range services { + // Don't add the entry if the port is -1 (invalid). The kubernetes plugin uses port -1 when a service/endpoint + // does not have any declared ports. 
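+		// Note on the weight normalization below: records that share a priority split a
+		// total weight of 100 between them. For example, two records with the default
+		// weight each contribute 100 to w[priority] (200 in total), so each SRV record
+		// is emitted with weight floor(100/200*100) = 50.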
+ if s.Port == -1 { + continue + } + w1 := 100.0 / float64(w[s.Priority]) + if s.Weight == 0 { + w1 *= 100 + } else { + w1 *= float64(s.Weight) + } + weight := uint16(math.Floor(w1)) + // weight should be at least 1 + if weight == 0 { + weight = 1 + } + + what, ip := s.HostType() + + switch what { + case dns.TypeCNAME: + addr := dns.Fqdn(s.Host) + srv := s.NewSRV(state.QName(), weight) + if ok := isDuplicate(dup, srv.Target, "", srv.Port); !ok { + records = append(records, srv) + } + if ok := isDuplicate(dup, srv.Target, addr, 0); !ok { + if resp, err := e.upstream.Lookup(ctx, state, addr, dns.TypeA); err == nil { + extra = append(extra, resp.Answer...) + } + if resp, err := e.upstream.Lookup(ctx, state, addr, dns.TypeAAAA); err == nil { + extra = append(extra, resp.Answer...) + } + } + case dns.TypeA, dns.TypeAAAA: + addr := s.Host + s.Host = msg.Domain(s.Key) + srv := s.NewSRV(state.QName(), weight) + + if ok := isDuplicate(dup, srv.Target, "", srv.Port); !ok { + records = append(records, srv) + } + + if ok := isDuplicate(dup, srv.Target, addr, 0); !ok { + hdr := dns.RR_Header{Name: srv.Target, Rrtype: what, Class: dns.ClassINET, Ttl: e.ttl} + + switch what { + case dns.TypeA: + extra = append(extra, &dns.A{Hdr: hdr, A: ip}) + case dns.TypeAAAA: + extra = append(extra, &dns.AAAA{Hdr: hdr, AAAA: ip}) + } + } + } + } + return records, extra +} + +// not sure if this is even needed. + +// item holds records. +type item struct { + name string // name of the record (either owner or something else unique). + port uint16 // port of the record (used for address records, A and AAAA). + addr string // address of the record (A and AAAA). +} + +// isDuplicate uses m to see if the combo (name, addr, port) already exists. If it does +// not exist already IsDuplicate will also add the record to the map. +func isDuplicate(m map[item]struct{}, name, addr string, port uint16) bool { + if addr != "" { + _, ok := m[item{name, 0, addr}] + if !ok { + m[item{name, 0, addr}] = struct{}{} + } + return ok + } + _, ok := m[item{name, port, ""}] + if !ok { + m[item{name, port, ""}] = struct{}{} + } + return ok +} diff --git a/plugin/k8s_external/setup.go b/plugin/k8s_external/setup.go new file mode 100644 index 0000000..f42f7de --- /dev/null +++ b/plugin/k8s_external/setup.go @@ -0,0 +1,88 @@ +package external + +import ( + "errors" + "strconv" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/upstream" +) + +const pluginName = "k8s_external" + +func init() { plugin.Register(pluginName, setup) } + +func setup(c *caddy.Controller) error { + e, err := parse(c) + if err != nil { + return plugin.Error("k8s_external", err) + } + + // Do this in OnStartup, so all plugins have been initialized. 
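+	// The handler lookup below only succeeds when the *kubernetes* plugin is configured in the same Server Block.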
+ c.OnStartup(func() error { + m := dnsserver.GetConfig(c).Handler("kubernetes") + if m == nil { + return plugin.Error(pluginName, errors.New("kubernetes plugin not loaded")) + } + + x, ok := m.(Externaler) + if !ok { + return plugin.Error(pluginName, errors.New("kubernetes plugin does not implement the Externaler interface")) + } + + e.externalFunc = x.External + e.externalAddrFunc = x.ExternalAddress + e.externalServicesFunc = x.ExternalServices + e.externalSerialFunc = x.ExternalSerial + return nil + }) + + e.upstream = upstream.New() + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + e.Next = next + return e + }) + + return nil +} + +func parse(c *caddy.Controller) (*External, error) { + e := New() + + for c.Next() { // external + e.Zones = plugin.OriginsFromArgsOrServerBlock(c.RemainingArgs(), c.ServerBlockKeys) + for c.NextBlock() { + switch c.Val() { + case "ttl": + args := c.RemainingArgs() + if len(args) == 0 { + return nil, c.ArgErr() + } + t, err := strconv.Atoi(args[0]) + if err != nil { + return nil, err + } + if t < 0 || t > 3600 { + return nil, c.Errf("ttl must be in range [0, 3600]: %d", t) + } + e.ttl = uint32(t) + case "apex": + args := c.RemainingArgs() + if len(args) == 0 { + return nil, c.ArgErr() + } + e.apex = args[0] + case "headless": + e.headless = true + case "fallthrough": + e.Fall.SetZonesFromArgs(c.RemainingArgs()) + default: + return nil, c.Errf("unknown property '%s'", c.Val()) + } + } + } + return e, nil +} diff --git a/plugin/k8s_external/setup_test.go b/plugin/k8s_external/setup_test.go new file mode 100644 index 0000000..8814554 --- /dev/null +++ b/plugin/k8s_external/setup_test.go @@ -0,0 +1,71 @@ +package external + +import ( + "testing" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/plugin/pkg/fall" +) + +func TestSetup(t *testing.T) { + tests := []struct { + input string + shouldErr bool + expectedZone string + expectedApex string + expectedHeadless bool + expectedFallthrough fall.F + }{ + {`k8s_external`, false, "", "dns", false, fall.Zero}, + {`k8s_external example.org`, false, "example.org.", "dns", false, fall.Zero}, + {`k8s_external example.org { + apex testdns +}`, false, "example.org.", "testdns", false, fall.Zero}, + {`k8s_external example.org { + headless +}`, false, "example.org.", "dns", true, fall.Zero}, + {`k8s_external example.org { + fallthrough +}`, false, "example.org.", "dns", false, fall.Root}, + {`k8s_external example.org { + fallthrough ip6.arpa inaddr.arpa foo.com +}`, false, "example.org.", "dns", false, + fall.F{Zones: []string{"ip6.arpa.", "inaddr.arpa.", "foo.com."}}}, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + e, err := parse(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected error but found %s for input %s", i, err, test.input) + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: Expected no error but found one for input %s. 
Error was: %v", i, test.input, err) + } + } + + if !test.shouldErr && test.expectedZone != "" { + if test.expectedZone != e.Zones[0] { + t.Errorf("Test %d, expected zone %q for input %s, got: %q", i, test.expectedZone, test.input, e.Zones[0]) + } + } + if !test.shouldErr { + if test.expectedApex != e.apex { + t.Errorf("Test %d, expected apex %q for input %s, got: %q", i, test.expectedApex, test.input, e.apex) + } + } + if !test.shouldErr { + if test.expectedHeadless != e.headless { + t.Errorf("Test %d, expected headless %q for input %s, got: %v", i, test.expectedApex, test.input, e.headless) + } + } + if !test.shouldErr { + if !e.Fall.Equal(test.expectedFallthrough) { + t.Errorf("Test %d, expected to be initialized with fallthrough %q for input %s, got: %v", i, test.expectedFallthrough, test.input, e.Fall) + } + } + } +} diff --git a/plugin/k8s_external/transfer.go b/plugin/k8s_external/transfer.go new file mode 100644 index 0000000..781f19f --- /dev/null +++ b/plugin/k8s_external/transfer.go @@ -0,0 +1,150 @@ +package external + +import ( + "context" + "strings" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/etcd/msg" + "github.com/coredns/coredns/plugin/kubernetes" + "github.com/coredns/coredns/plugin/transfer" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// Transfer implements transfer.Transferer +func (e *External) Transfer(zone string, serial uint32) (<-chan []dns.RR, error) { + z := plugin.Zones(e.Zones).Matches(zone) + if z != zone { + return nil, transfer.ErrNotAuthoritative + } + + ctx := context.Background() + ch := make(chan []dns.RR, 2) + if zone == "." { + zone = "" + } + state := request.Request{Zone: zone} + + // SOA + soa := e.soa(state) + ch <- []dns.RR{soa} + if serial != 0 && serial >= soa.Serial { + close(ch) + return ch, nil + } + + go func() { + // Add NS + nsName := "ns1." + e.apex + "." 
+ zone + nsHdr := dns.RR_Header{Name: zone, Rrtype: dns.TypeNS, Ttl: e.ttl, Class: dns.ClassINET} + ch <- []dns.RR{&dns.NS{Hdr: nsHdr, Ns: nsName}} + + // Add Nameserver A/AAAA records + nsRecords := e.externalAddrFunc(state, e.headless) + for i := range nsRecords { + // externalAddrFunc returns incomplete header names, correct here + nsRecords[i].Header().Name = nsName + nsRecords[i].Header().Ttl = e.ttl + ch <- []dns.RR{nsRecords[i]} + } + + svcs, headlessSvcs := e.externalServicesFunc(zone, e.headless) + srvSeen := make(map[string]struct{}) + + for i := range svcs { + name := msg.Domain(svcs[i].Key) + + if svcs[i].TargetStrip == 0 { + // Add Service A/AAAA records + s := request.Request{Req: &dns.Msg{Question: []dns.Question{{Name: name}}}} + as, _ := e.a(ctx, []msg.Service{svcs[i]}, s) + if len(as) > 0 { + ch <- as + } + aaaas, _ := e.aaaa(ctx, []msg.Service{svcs[i]}, s) + if len(aaaas) > 0 { + ch <- aaaas + } + // Add bare SRV record, ensuring uniqueness + recs, _ := e.srv(ctx, []msg.Service{svcs[i]}, s) + for _, srv := range recs { + if !nameSeen(srvSeen, srv) { + ch <- []dns.RR{srv} + } + } + continue + } + // Add full SRV record, ensuring uniqueness + s := request.Request{Req: &dns.Msg{Question: []dns.Question{{Name: name}}}} + recs, _ := e.srv(ctx, []msg.Service{svcs[i]}, s) + for _, srv := range recs { + if !nameSeen(srvSeen, srv) { + ch <- []dns.RR{srv} + } + } + } + for key, svcs := range headlessSvcs { + // we have to strip the leading key because it's either port.protocol or endpoint + name := msg.Domain(key[:strings.LastIndex(key, "/")]) + switchKey := key[strings.LastIndex(key, "/")+1:] + switch switchKey { + case kubernetes.Endpoint: + // headless.namespace.example.com records + s := request.Request{Req: &dns.Msg{Question: []dns.Question{{Name: name}}}} + as, _ := e.a(ctx, svcs, s) + if len(as) > 0 { + ch <- as + } + aaaas, _ := e.aaaa(ctx, svcs, s) + if len(aaaas) > 0 { + ch <- aaaas + } + // Add bare SRV record, ensuring uniqueness + recs, _ := e.srv(ctx, svcs, s) + ch <- recs + for _, srv := range recs { + ch <- []dns.RR{srv} + } + + for i := range svcs { + // endpoint.headless.namespace.example.com record + s := request.Request{Req: &dns.Msg{Question: []dns.Question{{Name: msg.Domain(svcs[i].Key)}}}} + + as, _ := e.a(ctx, []msg.Service{svcs[i]}, s) + if len(as) > 0 { + ch <- as + } + aaaas, _ := e.aaaa(ctx, []msg.Service{svcs[i]}, s) + if len(aaaas) > 0 { + ch <- aaaas + } + // Add bare SRV record, ensuring uniqueness + recs, _ := e.srv(ctx, []msg.Service{svcs[i]}, s) + ch <- recs + for _, srv := range recs { + ch <- []dns.RR{srv} + } + } + + case kubernetes.PortProtocol: + s := request.Request{Req: &dns.Msg{Question: []dns.Question{{Name: name}}}} + recs, _ := e.srv(ctx, svcs, s) + ch <- recs + } + } + ch <- []dns.RR{soa} + close(ch) + }() + + return ch, nil +} + +func nameSeen(namesSeen map[string]struct{}, rr dns.RR) bool { + if _, duplicate := namesSeen[rr.Header().Name]; duplicate { + return true + } + namesSeen[rr.Header().Name] = struct{}{} + return false +} diff --git a/plugin/k8s_external/transfer_test.go b/plugin/k8s_external/transfer_test.go new file mode 100644 index 0000000..4f525f9 --- /dev/null +++ b/plugin/k8s_external/transfer_test.go @@ -0,0 +1,148 @@ +package external + +import ( + "strings" + "testing" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/kubernetes" + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/plugin/transfer" + + "github.com/miekg/dns" +) + +func TestImplementsTransferer(t 
*testing.T) { + var e plugin.Handler + e = &External{} + _, ok := e.(transfer.Transferer) + if !ok { + t.Error("Transferer not implemented") + } +} + +func TestTransferAXFR(t *testing.T) { + k := kubernetes.New([]string{"cluster.local."}) + k.Namespaces = map[string]struct{}{"testns": {}} + k.APIConn = &external{} + + e := New() + e.headless = true + e.Zones = []string{"example.com."} + e.externalFunc = k.External + e.externalAddrFunc = externalAddress // internal test function + e.externalSerialFunc = externalSerial // internal test function + e.externalServicesFunc = k.ExternalServices + + ch, err := e.Transfer("example.com.", 0) + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + var records []dns.RR + for rrs := range ch { + records = append(records, rrs...) + } + + expect := []dns.RR{} + for _, tc := range append(tests, testsApex...) { + if tc.Rcode != dns.RcodeSuccess { + continue + } + + for _, ans := range tc.Answer { + // Exclude wildcard test cases + if strings.Contains(ans.Header().Name, "*") { + continue + } + + // Exclude TXT records + if ans.Header().Rrtype == dns.TypeTXT { + continue + } + + // Exclude PTR records + if ans.Header().Rrtype == dns.TypePTR { + continue + } + + expect = append(expect, ans) + } + } + + diff := difference(expect, records) + if len(diff) != 0 { + t.Errorf("Got back %d records that do not exist in test cases, should be 0:", len(diff)) + for _, rec := range diff { + t.Errorf("%+v", rec) + } + } + + diff = difference(records, expect) + if len(diff) != 0 { + t.Errorf("Result is missing %d records, should be 0:", len(diff)) + for _, rec := range diff { + t.Errorf("%+v", rec) + } + } +} + +func TestTransferIXFR(t *testing.T) { + k := kubernetes.New([]string{"cluster.local."}) + k.Namespaces = map[string]struct{}{"testns": {}} + k.APIConn = &external{} + + e := New() + e.Zones = []string{"example.com."} + e.headless = true + e.externalFunc = k.External + e.externalAddrFunc = externalAddress // internal test function + e.externalSerialFunc = externalSerial // internal test function + e.externalServicesFunc = k.ExternalServices + + ch, err := e.Transfer("example.com.", externalSerial("example.com.")) + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + var records []dns.RR + for rrs := range ch { + records = append(records, rrs...) + } + + expect := []dns.RR{ + test.SOA("example.com. 5 IN SOA ns1.dns.example.com. hostmaster.dns.example.com. 
1499347823 7200 1800 86400 5"), + } + + diff := difference(expect, records) + if len(diff) != 0 { + t.Errorf("Got back %d records that do not exist in test cases, should be 0:", len(diff)) + for _, rec := range diff { + t.Errorf("%+v", rec) + } + } + + diff = difference(records, expect) + if len(diff) != 0 { + t.Errorf("Result is missing %d records, should be 0:", len(diff)) + for _, rec := range diff { + t.Errorf("%+v", rec) + } + } +} + +// difference shows what we're missing when comparing two RR slices +func difference(testRRs []dns.RR, gotRRs []dns.RR) []dns.RR { + expectedRRs := map[string]struct{}{} + for _, rr := range testRRs { + expectedRRs[rr.String()] = struct{}{} + } + + foundRRs := []dns.RR{} + for _, rr := range gotRRs { + if _, ok := expectedRRs[rr.String()]; !ok { + foundRRs = append(foundRRs, rr) + } + } + return foundRRs +} diff --git a/plugin/kubernetes/README.md b/plugin/kubernetes/README.md new file mode 100644 index 0000000..0c50333 --- /dev/null +++ b/plugin/kubernetes/README.md @@ -0,0 +1,242 @@ +# kubernetes + +## Name + +*kubernetes* - enables reading zone data from a Kubernetes cluster. + +## Description + +This plugin implements the [Kubernetes DNS-Based Service Discovery +Specification](https://github.com/kubernetes/dns/blob/master/docs/specification.md). + +CoreDNS running the kubernetes plugin can be used as a replacement for kube-dns in a kubernetes +cluster. See the [deployment](https://github.com/coredns/deployment) repository for details on [how +to deploy CoreDNS in Kubernetes](https://github.com/coredns/deployment/tree/master/kubernetes). + +[stubDomains and upstreamNameservers](https://kubernetes.io/blog/2017/04/configuring-private-dns-zones-upstream-nameservers-kubernetes/) +are implemented via the *forward* plugin. See the examples below. + +This plugin can only be used once per Server Block. + +## Syntax + +~~~ +kubernetes [ZONES...] +~~~ + +With only the plugin specified, the *kubernetes* plugin will default to the zone specified in +the server's block. It will handle all queries in that zone and connect to Kubernetes in-cluster. It +will not provide PTR records for services or A records for pods. If **ZONES** is used it specifies +all the zones the plugin should be authoritative for. + +``` +kubernetes [ZONES...] { + endpoint URL + tls CERT KEY CACERT + kubeconfig KUBECONFIG [CONTEXT] + namespaces NAMESPACE... + labels EXPRESSION + pods POD-MODE + endpoint_pod_names + ttl TTL + noendpoints + fallthrough [ZONES...] + ignore empty_service +} +``` + +* `endpoint` specifies the **URL** for a remote k8s API endpoint. + If omitted, it will connect to k8s in-cluster using the cluster service account. +* `tls` **CERT** **KEY** **CACERT** are the TLS cert, key and the CA cert file names for remote k8s connection. + This option is ignored if connecting in-cluster (i.e. endpoint is not specified). +* `kubeconfig` **KUBECONFIG [CONTEXT]** authenticates the connection to a remote k8s cluster using a kubeconfig file. + **[CONTEXT]** is optional, if not set, then the current context specified in kubeconfig will be used. + It supports TLS, username and password, or token-based authentication. + This option is ignored if connecting in-cluster (i.e., the endpoint is not specified). +* `namespaces` **NAMESPACE [NAMESPACE...]** only exposes the k8s namespaces listed. + If this option is omitted all namespaces are exposed +* `namespace_labels` **EXPRESSION** only expose the records for Kubernetes namespaces that match this label selector. 
+ The label selector syntax is described in the + [Kubernetes User Guide - Labels](https://kubernetes.io/docs/user-guide/labels/). An example that + only exposes namespaces labeled as "istio-injection=enabled", would use: + `labels istio-injection=enabled`. +* `labels` **EXPRESSION** only exposes the records for Kubernetes objects that match this label selector. + The label selector syntax is described in the + [Kubernetes User Guide - Labels](https://kubernetes.io/docs/user-guide/labels/). An example that + only exposes objects labeled as "application=nginx" in the "staging" or "qa" environments, would + use: `labels environment in (staging, qa),application=nginx`. +* `pods` **POD-MODE** sets the mode for handling IP-based pod A records, e.g. + `1-2-3-4.ns.pod.cluster.local. in A 1.2.3.4`. + This option is provided to facilitate use of SSL certs when connecting directly to pods. Valid + values for **POD-MODE**: + + * `disabled`: Default. Do not process pod requests, always returning `NXDOMAIN` + * `insecure`: Always return an A record with IP from request (without checking k8s). This option + is vulnerable to abuse if used maliciously in conjunction with wildcard SSL certs. This + option is provided for backward compatibility with kube-dns. + * `verified`: Return an A record if there exists a pod in same namespace with matching IP. This + option requires substantially more memory than in insecure mode, since it will maintain a watch + on all pods. + +* `endpoint_pod_names` uses the pod name of the pod targeted by the endpoint as + the endpoint name in A records, e.g., + `endpoint-name.my-service.namespace.svc.cluster.local. in A 1.2.3.4` + By default, the endpoint-name name selection is as follows: Use the hostname + of the endpoint, or if hostname is not set, use the dashed form of the endpoint + IP address (e.g., `1-2-3-4.my-service.namespace.svc.cluster.local.`) + If this directive is included, then name selection for endpoints changes as + follows: Use the hostname of the endpoint, or if hostname is not set, use the + pod name of the pod targeted by the endpoint. If there is no pod targeted by + the endpoint or pod name is longer than 63, use the dashed IP address form. +* `ttl` allows you to set a custom TTL for responses. The default is 5 seconds. The minimum TTL allowed is + 0 seconds, and the maximum is capped at 3600 seconds. Setting TTL to 0 will prevent records from being cached. +* `noendpoints` will turn off the serving of endpoint records by disabling the watch on endpoints. + All endpoint queries and headless service queries will result in an NXDOMAIN. +* `fallthrough` **[ZONES...]** If a query for a record in the zones for which the plugin is authoritative + results in NXDOMAIN, normally that is what the response will be. However, if you specify this option, + the query will instead be passed on down the plugin chain, which can include another plugin to handle + the query. If **[ZONES...]** is omitted, then fallthrough happens for all zones for which the plugin + is authoritative. If specific zones are listed (for example `in-addr.arpa` and `ip6.arpa`), then only + queries for those zones will be subject to fallthrough. +* `ignore empty_service` returns NXDOMAIN for services without any ready endpoint addresses (e.g., ready pods). + This allows the querying pod to continue searching for the service in the search path. + The search path could, for example, include another Kubernetes cluster. + +Enabling zone transfer is done by using the *transfer* plugin. 
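+
+For illustration, here is a sketch of a Server Block that combines a few of the options above with the *transfer* plugin (in this sketch `to *` permits zone transfers to any client, which a real deployment would normally restrict):
+
+~~~ txt
+cluster.local in-addr.arpa ip6.arpa {
+    kubernetes {
+        pods verified
+        ttl 30
+        fallthrough in-addr.arpa ip6.arpa
+    }
+    transfer {
+        to *
+    }
+}
+~~~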
+ +## Startup + +When CoreDNS starts with the *kubernetes* plugin enabled, it will delay serving DNS for up to 5 seconds +until it can connect to the Kubernetes API and synchronize all object watches. If this cannot happen within +5 seconds, then CoreDNS will start serving DNS while the *kubernetes* plugin continues to try to connect +and synchronize all object watches. CoreDNS will answer SERVFAIL to any request made for a Kubernetes record +that has not yet been synchronized. + +## Monitoring Kubernetes Endpoints + +The *kubernetes* plugin watches Endpoints via the `discovery.EndpointSlices` API. + +## Ready + +This plugin reports readiness to the ready plugin. This will happen after it has synced to the +Kubernetes API. + +## Examples + +Handle all queries in the `cluster.local` zone. Connect to Kubernetes in-cluster. Also handle all +`in-addr.arpa` `PTR` requests for `10.0.0.0/17` . Verify the existence of pods when answering pod +requests. + +~~~ txt +10.0.0.0/17 cluster.local { + kubernetes { + pods verified + } +} +~~~ + +Or you can selectively expose some namespaces: + +~~~ txt +kubernetes cluster.local { + namespaces test staging +} +~~~ + +Connect to Kubernetes with CoreDNS running outside the cluster: + +~~~ txt +kubernetes cluster.local { + endpoint https://k8s-endpoint:8443 + tls cert key cacert +} +~~~ + +## stubDomains and upstreamNameservers + +Here we use the *forward* plugin to implement a stubDomain that forwards `example.local` to the nameserver `10.100.0.10:53`. +Also configured is an upstreamNameserver `8.8.8.8:53` that will be used for resolving names that do not fall in `cluster.local` +or `example.local`. + +~~~ txt +cluster.local:53 { + kubernetes cluster.local +} +example.local { + forward . 10.100.0.10:53 +} + +. { + forward . 8.8.8.8:53 +} +~~~ + +The configuration above represents the following Kube-DNS stubDomains and upstreamNameservers configuration. + +~~~ txt +stubDomains: | + {“example.local”: [“10.100.0.10:53”]} +upstreamNameservers: | + [“8.8.8.8:53”] +~~~ + +## AutoPath + +The *kubernetes* plugin can be used in conjunction with the *autopath* plugin. Using this +feature enables server-side domain search path completion in Kubernetes clusters. Note: `pods` must +be set to `verified` for this to function properly. Furthermore, the remote IP address in the DNS +packet received by CoreDNS must be the IP address of the Pod that sent the request. + + cluster.local { + autopath @kubernetes + kubernetes { + pods verified + } + } + +## Metadata + +The kubernetes plugin will publish the following metadata, if the *metadata* +plugin is also enabled: + + * `kubernetes/endpoint`: the endpoint name in the query + * `kubernetes/kind`: the resource kind (pod or svc) in the query + * `kubernetes/namespace`: the namespace in the query + * `kubernetes/port-name`: the port name in an SRV query + * `kubernetes/protocol`: the protocol in an SRV query + * `kubernetes/service`: the service name in the query + * `kubernetes/client-namespace`: the client pod's namespace (see requirements below) + * `kubernetes/client-pod-name`: the client pod's name (see requirements below) + +The `kubernetes/client-namespace` and `kubernetes/client-pod-name` metadata work by reconciling the +client IP address in the DNS request packet to a known pod IP address. Therefore the following is required: + * `pods verified` mode must be enabled + * the remote IP address in the DNS packet received by CoreDNS must be the IP address + of the Pod that sent the request. 
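+
+As a minimal sketch, the following Server Block publishes this metadata, assuming the *metadata* plugin is enabled alongside *kubernetes* with `pods verified`:
+
+~~~ txt
+cluster.local {
+    metadata
+    kubernetes {
+        pods verified
+    }
+}
+~~~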
+ +## Metrics + +If monitoring is enabled (via the *prometheus* plugin) then the following metrics are exported: + +* `coredns_kubernetes_dns_programming_duration_seconds{service_kind}` - Exports the + [DNS programming latency SLI](https://github.com/kubernetes/community/blob/master/sig-scalability/slos/dns_programming_latency.md). + The metrics has the `service_kind` label that identifies the kind of the + [kubernetes service](https://kubernetes.io/docs/concepts/services-networking/service). + It may take one of the three values: + * `cluster_ip` + * `headless_with_selector` + * `headless_without_selector` + +The following are client level metrics to monitor apiserver request latency & status codes. `verb` identifies the apiserver [request type](https://kubernetes.io/docs/reference/using-api/api-concepts/#single-resource-api) and `host` denotes the apiserver endpoint. +* `coredns_kubernetes_rest_client_request_duration_seconds{verb, host}` - captures apiserver request latency perceived by client grouped by `verb` and `host`. +* `coredns_kubernetes_rest_client_rate_limiter_duration_seconds{verb, host}` - captures apiserver request latency contributed by client side rate limiter grouped by `verb` & `host`. +* `coredns_kubernetes_rest_client_requests_total{method, code, host}` - captures total apiserver requests grouped by `method`, `status_code` & `host`. + +## Bugs + +The duration metric only supports the "headless\_with\_selector" service currently. + +## See Also + +See the *autopath* plugin to enable search path optimizations. And use the *transfer* plugin to +enable outgoing zone transfers. diff --git a/plugin/kubernetes/autopath.go b/plugin/kubernetes/autopath.go new file mode 100644 index 0000000..e873897 --- /dev/null +++ b/plugin/kubernetes/autopath.go @@ -0,0 +1,62 @@ +package kubernetes + +import ( + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/kubernetes/object" + "github.com/coredns/coredns/request" +) + +// AutoPath implements the AutoPathFunc call from the autopath plugin. +// It returns a per-query search path or nil indicating no searchpathing should happen. +func (k *Kubernetes) AutoPath(state request.Request) []string { + // Check if the query falls in a zone we are actually authoritative for and thus if we want autopath. + zone := plugin.Zones(k.Zones).Matches(state.Name()) + if zone == "" { + return nil + } + + // cluster.local { + // autopath @kubernetes + // kubernetes { + // pods verified # + // } + // } + // if pods != verified will cause panic and return SERVFAIL, expect worked as normal without autopath function + if !k.opts.initPodCache { + return nil + } + + ip := state.IP() + + pod := k.podWithIP(ip) + if pod == nil { + return nil + } + + search := make([]string, 3) + if zone == "." { + search[0] = pod.Namespace + ".svc." + search[1] = "svc." + search[2] = "." + } else { + search[0] = pod.Namespace + ".svc." + zone + search[1] = "svc." + zone + search[2] = zone + } + + search = append(search, k.autoPathSearch...) + search = append(search, "") // sentinel + return search +} + +// podWithIP returns the api.Pod for source IP. It returns nil if nothing can be found. 
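+// The lookup is only attempted in verified pods mode, as that is the only mode in which a pod cache is maintained.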
+func (k *Kubernetes) podWithIP(ip string) *object.Pod { + if k.podMode != podModeVerified { + return nil + } + ps := k.APIConn.PodIndex(ip) + if len(ps) == 0 { + return nil + } + return ps[0] +} diff --git a/plugin/kubernetes/controller.go b/plugin/kubernetes/controller.go new file mode 100644 index 0000000..e7db294 --- /dev/null +++ b/plugin/kubernetes/controller.go @@ -0,0 +1,673 @@ +package kubernetes + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/coredns/coredns/plugin/kubernetes/object" + + api "k8s.io/api/core/v1" + discovery "k8s.io/api/discovery/v1" + meta "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/labels" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/watch" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/tools/cache" +) + +const ( + podIPIndex = "PodIP" + svcNameNamespaceIndex = "ServiceNameNamespace" + svcIPIndex = "ServiceIP" + svcExtIPIndex = "ServiceExternalIP" + epNameNamespaceIndex = "EndpointNameNamespace" + epIPIndex = "EndpointsIP" +) + +type dnsController interface { + ServiceList() []*object.Service + EndpointsList() []*object.Endpoints + SvcIndex(string) []*object.Service + SvcIndexReverse(string) []*object.Service + SvcExtIndexReverse(string) []*object.Service + PodIndex(string) []*object.Pod + EpIndex(string) []*object.Endpoints + EpIndexReverse(string) []*object.Endpoints + + GetNodeByName(context.Context, string) (*api.Node, error) + GetNamespaceByName(string) (*object.Namespace, error) + + Run() + HasSynced() bool + Stop() error + + // Modified returns the timestamp of the most recent changes to services. If the passed bool is true, it should + // return the timestamp of the most recent changes to services with external facing IP addresses + Modified(bool) int64 +} + +type dnsControl struct { + // modified tracks timestamp of the most recent changes + // It needs to be first because it is guaranteed to be 8-byte + // aligned ( we use sync.LoadAtomic with this ) + modified int64 + // extModified tracks timestamp of the most recent changes to + // services with external facing IP addresses + extModified int64 + + client kubernetes.Interface + + selector labels.Selector + namespaceSelector labels.Selector + + svcController cache.Controller + podController cache.Controller + epController cache.Controller + nsController cache.Controller + + svcLister cache.Indexer + podLister cache.Indexer + epLister cache.Indexer + nsLister cache.Store + + // stopLock is used to enforce only a single call to Stop is active. + // Needed because we allow stopping through an http endpoint and + // allowing concurrent stoppers leads to stack traces. + stopLock sync.Mutex + shutdown bool + stopCh chan struct{} + + zones []string + endpointNameMode bool +} + +type dnsControlOpts struct { + initPodCache bool + initEndpointsCache bool + ignoreEmptyService bool + + // Label handling. + labelSelector *meta.LabelSelector + selector labels.Selector + namespaceLabelSelector *meta.LabelSelector + namespaceSelector labels.Selector + + zones []string + endpointNameMode bool +} + +// newdnsController creates a controller for CoreDNS. 
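+// The pod and endpoint informers are created here as well, but they are only run when the corresponding caches are enabled through opts.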
+func newdnsController(ctx context.Context, kubeClient kubernetes.Interface, opts dnsControlOpts) *dnsControl { + dns := dnsControl{ + client: kubeClient, + selector: opts.selector, + namespaceSelector: opts.namespaceSelector, + stopCh: make(chan struct{}), + zones: opts.zones, + endpointNameMode: opts.endpointNameMode, + } + + dns.svcLister, dns.svcController = object.NewIndexerInformer( + &cache.ListWatch{ + ListFunc: serviceListFunc(ctx, dns.client, api.NamespaceAll, dns.selector), + WatchFunc: serviceWatchFunc(ctx, dns.client, api.NamespaceAll, dns.selector), + }, + &api.Service{}, + cache.ResourceEventHandlerFuncs{AddFunc: dns.Add, UpdateFunc: dns.Update, DeleteFunc: dns.Delete}, + cache.Indexers{svcNameNamespaceIndex: svcNameNamespaceIndexFunc, svcIPIndex: svcIPIndexFunc, svcExtIPIndex: svcExtIPIndexFunc}, + object.DefaultProcessor(object.ToService, nil), + ) + + podLister, podController := object.NewIndexerInformer( + &cache.ListWatch{ + ListFunc: podListFunc(ctx, dns.client, api.NamespaceAll, dns.selector), + WatchFunc: podWatchFunc(ctx, dns.client, api.NamespaceAll, dns.selector), + }, + &api.Pod{}, + cache.ResourceEventHandlerFuncs{AddFunc: dns.Add, UpdateFunc: dns.Update, DeleteFunc: dns.Delete}, + cache.Indexers{podIPIndex: podIPIndexFunc}, + object.DefaultProcessor(object.ToPod, nil), + ) + dns.podLister = podLister + if opts.initPodCache { + dns.podController = podController + } + + epLister, epController := object.NewIndexerInformer( + &cache.ListWatch{ + ListFunc: endpointSliceListFunc(ctx, dns.client, api.NamespaceAll, dns.selector), + WatchFunc: endpointSliceWatchFunc(ctx, dns.client, api.NamespaceAll, dns.selector), + }, + &discovery.EndpointSlice{}, + cache.ResourceEventHandlerFuncs{AddFunc: dns.Add, UpdateFunc: dns.Update, DeleteFunc: dns.Delete}, + cache.Indexers{epNameNamespaceIndex: epNameNamespaceIndexFunc, epIPIndex: epIPIndexFunc}, + object.DefaultProcessor(object.EndpointSliceToEndpoints, dns.EndpointSliceLatencyRecorder()), + ) + dns.epLister = epLister + if opts.initEndpointsCache { + dns.epController = epController + } + + dns.nsLister, dns.nsController = object.NewIndexerInformer( + &cache.ListWatch{ + ListFunc: namespaceListFunc(ctx, dns.client, dns.namespaceSelector), + WatchFunc: namespaceWatchFunc(ctx, dns.client, dns.namespaceSelector), + }, + &api.Namespace{}, + cache.ResourceEventHandlerFuncs{}, + cache.Indexers{}, + object.DefaultProcessor(object.ToNamespace, nil), + ) + + return &dns +} + +func (dns *dnsControl) EndpointsLatencyRecorder() *object.EndpointLatencyRecorder { + return &object.EndpointLatencyRecorder{ + ServiceFunc: func(o meta.Object) []*object.Service { + return dns.SvcIndex(object.ServiceKey(o.GetName(), o.GetNamespace())) + }, + } +} +func (dns *dnsControl) EndpointSliceLatencyRecorder() *object.EndpointLatencyRecorder { + return &object.EndpointLatencyRecorder{ + ServiceFunc: func(o meta.Object) []*object.Service { + return dns.SvcIndex(object.ServiceKey(o.GetLabels()[discovery.LabelServiceName], o.GetNamespace())) + }, + } +} + +func podIPIndexFunc(obj interface{}) ([]string, error) { + p, ok := obj.(*object.Pod) + if !ok { + return nil, errObj + } + return []string{p.PodIP}, nil +} + +func svcIPIndexFunc(obj interface{}) ([]string, error) { + svc, ok := obj.(*object.Service) + if !ok { + return nil, errObj + } + idx := make([]string, len(svc.ClusterIPs)) + copy(idx, svc.ClusterIPs) + return idx, nil +} + +func svcExtIPIndexFunc(obj interface{}) ([]string, error) { + svc, ok := obj.(*object.Service) + if !ok { + return nil, errObj 
+ } + idx := make([]string, len(svc.ExternalIPs)) + copy(idx, svc.ExternalIPs) + return idx, nil +} + +func svcNameNamespaceIndexFunc(obj interface{}) ([]string, error) { + s, ok := obj.(*object.Service) + if !ok { + return nil, errObj + } + return []string{s.Index}, nil +} + +func epNameNamespaceIndexFunc(obj interface{}) ([]string, error) { + s, ok := obj.(*object.Endpoints) + if !ok { + return nil, errObj + } + return []string{s.Index}, nil +} + +func epIPIndexFunc(obj interface{}) ([]string, error) { + ep, ok := obj.(*object.Endpoints) + if !ok { + return nil, errObj + } + return ep.IndexIP, nil +} + +func serviceListFunc(ctx context.Context, c kubernetes.Interface, ns string, s labels.Selector) func(meta.ListOptions) (runtime.Object, error) { + return func(opts meta.ListOptions) (runtime.Object, error) { + if s != nil { + opts.LabelSelector = s.String() + } + return c.CoreV1().Services(ns).List(ctx, opts) + } +} + +func podListFunc(ctx context.Context, c kubernetes.Interface, ns string, s labels.Selector) func(meta.ListOptions) (runtime.Object, error) { + return func(opts meta.ListOptions) (runtime.Object, error) { + if s != nil { + opts.LabelSelector = s.String() + } + if len(opts.FieldSelector) > 0 { + opts.FieldSelector = opts.FieldSelector + "," + } + opts.FieldSelector = opts.FieldSelector + "status.phase!=Succeeded,status.phase!=Failed,status.phase!=Unknown" + return c.CoreV1().Pods(ns).List(ctx, opts) + } +} + +func endpointSliceListFunc(ctx context.Context, c kubernetes.Interface, ns string, s labels.Selector) func(meta.ListOptions) (runtime.Object, error) { + return func(opts meta.ListOptions) (runtime.Object, error) { + if s != nil { + opts.LabelSelector = s.String() + } + return c.DiscoveryV1().EndpointSlices(ns).List(ctx, opts) + } +} + +func namespaceListFunc(ctx context.Context, c kubernetes.Interface, s labels.Selector) func(meta.ListOptions) (runtime.Object, error) { + return func(opts meta.ListOptions) (runtime.Object, error) { + if s != nil { + opts.LabelSelector = s.String() + } + return c.CoreV1().Namespaces().List(ctx, opts) + } +} + +func serviceWatchFunc(ctx context.Context, c kubernetes.Interface, ns string, s labels.Selector) func(options meta.ListOptions) (watch.Interface, error) { + return func(options meta.ListOptions) (watch.Interface, error) { + if s != nil { + options.LabelSelector = s.String() + } + return c.CoreV1().Services(ns).Watch(ctx, options) + } +} + +func podWatchFunc(ctx context.Context, c kubernetes.Interface, ns string, s labels.Selector) func(options meta.ListOptions) (watch.Interface, error) { + return func(options meta.ListOptions) (watch.Interface, error) { + if s != nil { + options.LabelSelector = s.String() + } + if len(options.FieldSelector) > 0 { + options.FieldSelector = options.FieldSelector + "," + } + options.FieldSelector = options.FieldSelector + "status.phase!=Succeeded,status.phase!=Failed,status.phase!=Unknown" + return c.CoreV1().Pods(ns).Watch(ctx, options) + } +} + +func endpointSliceWatchFunc(ctx context.Context, c kubernetes.Interface, ns string, s labels.Selector) func(options meta.ListOptions) (watch.Interface, error) { + return func(options meta.ListOptions) (watch.Interface, error) { + if s != nil { + options.LabelSelector = s.String() + } + return c.DiscoveryV1().EndpointSlices(ns).Watch(ctx, options) + } +} + +func namespaceWatchFunc(ctx context.Context, c kubernetes.Interface, s labels.Selector) func(options meta.ListOptions) (watch.Interface, error) { + return func(options meta.ListOptions) (watch.Interface, 
error) { + if s != nil { + options.LabelSelector = s.String() + } + return c.CoreV1().Namespaces().Watch(ctx, options) + } +} + +// Stop stops the controller. +func (dns *dnsControl) Stop() error { + dns.stopLock.Lock() + defer dns.stopLock.Unlock() + + // Only try draining the workqueue if we haven't already. + if !dns.shutdown { + close(dns.stopCh) + dns.shutdown = true + + return nil + } + + return fmt.Errorf("shutdown already in progress") +} + +// Run starts the controller. +func (dns *dnsControl) Run() { + go dns.svcController.Run(dns.stopCh) + if dns.epController != nil { + go func() { + dns.epController.Run(dns.stopCh) + }() + } + if dns.podController != nil { + go dns.podController.Run(dns.stopCh) + } + go dns.nsController.Run(dns.stopCh) + <-dns.stopCh +} + +// HasSynced calls on all controllers. +func (dns *dnsControl) HasSynced() bool { + a := dns.svcController.HasSynced() + b := true + if dns.epController != nil { + b = dns.epController.HasSynced() + } + c := true + if dns.podController != nil { + c = dns.podController.HasSynced() + } + d := dns.nsController.HasSynced() + return a && b && c && d +} + +func (dns *dnsControl) ServiceList() (svcs []*object.Service) { + os := dns.svcLister.List() + for _, o := range os { + s, ok := o.(*object.Service) + if !ok { + continue + } + svcs = append(svcs, s) + } + return svcs +} + +func (dns *dnsControl) EndpointsList() (eps []*object.Endpoints) { + os := dns.epLister.List() + for _, o := range os { + ep, ok := o.(*object.Endpoints) + if !ok { + continue + } + eps = append(eps, ep) + } + return eps +} + +func (dns *dnsControl) PodIndex(ip string) (pods []*object.Pod) { + os, err := dns.podLister.ByIndex(podIPIndex, ip) + if err != nil { + return nil + } + for _, o := range os { + p, ok := o.(*object.Pod) + if !ok { + continue + } + pods = append(pods, p) + } + return pods +} + +func (dns *dnsControl) SvcIndex(idx string) (svcs []*object.Service) { + os, err := dns.svcLister.ByIndex(svcNameNamespaceIndex, idx) + if err != nil { + return nil + } + for _, o := range os { + s, ok := o.(*object.Service) + if !ok { + continue + } + svcs = append(svcs, s) + } + return svcs +} + +func (dns *dnsControl) SvcIndexReverse(ip string) (svcs []*object.Service) { + os, err := dns.svcLister.ByIndex(svcIPIndex, ip) + if err != nil { + return nil + } + + for _, o := range os { + s, ok := o.(*object.Service) + if !ok { + continue + } + svcs = append(svcs, s) + } + return svcs +} + +func (dns *dnsControl) SvcExtIndexReverse(ip string) (svcs []*object.Service) { + os, err := dns.svcLister.ByIndex(svcExtIPIndex, ip) + if err != nil { + return nil + } + + for _, o := range os { + s, ok := o.(*object.Service) + if !ok { + continue + } + svcs = append(svcs, s) + } + return svcs +} + +func (dns *dnsControl) EpIndex(idx string) (ep []*object.Endpoints) { + os, err := dns.epLister.ByIndex(epNameNamespaceIndex, idx) + if err != nil { + return nil + } + for _, o := range os { + e, ok := o.(*object.Endpoints) + if !ok { + continue + } + ep = append(ep, e) + } + return ep +} + +func (dns *dnsControl) EpIndexReverse(ip string) (ep []*object.Endpoints) { + os, err := dns.epLister.ByIndex(epIPIndex, ip) + if err != nil { + return nil + } + for _, o := range os { + e, ok := o.(*object.Endpoints) + if !ok { + continue + } + ep = append(ep, e) + } + return ep +} + +// GetNodeByName return the node by name. If nothing is found an error is +// returned. This query causes a round trip to the k8s API server, so use +// sparingly. Currently, this is only used for Federation. 
+func (dns *dnsControl) GetNodeByName(ctx context.Context, name string) (*api.Node, error) { + v1node, err := dns.client.CoreV1().Nodes().Get(ctx, name, meta.GetOptions{}) + return v1node, err +} + +// GetNamespaceByName returns the namespace by name. If nothing is found an error is returned. +func (dns *dnsControl) GetNamespaceByName(name string) (*object.Namespace, error) { + o, exists, err := dns.nsLister.GetByKey(name) + if err != nil { + return nil, err + } + if !exists { + return nil, fmt.Errorf("namespace not found") + } + ns, ok := o.(*object.Namespace) + if !ok { + return nil, fmt.Errorf("found key but not namespace") + } + return ns, nil +} + +func (dns *dnsControl) Add(obj interface{}) { dns.updateModified() } +func (dns *dnsControl) Delete(obj interface{}) { dns.updateModified() } +func (dns *dnsControl) Update(oldObj, newObj interface{}) { dns.detectChanges(oldObj, newObj) } + +// detectChanges detects changes in objects, and updates the modified timestamp +func (dns *dnsControl) detectChanges(oldObj, newObj interface{}) { + // If both objects have the same resource version, they are identical. + if newObj != nil && oldObj != nil && (oldObj.(meta.Object).GetResourceVersion() == newObj.(meta.Object).GetResourceVersion()) { + return + } + obj := newObj + if obj == nil { + obj = oldObj + } + switch ob := obj.(type) { + case *object.Service: + imod, emod := serviceModified(oldObj, newObj) + if imod { + dns.updateModified() + } + if emod { + dns.updateExtModified() + } + case *object.Pod: + dns.updateModified() + case *object.Endpoints: + if !endpointsEquivalent(oldObj.(*object.Endpoints), newObj.(*object.Endpoints)) { + dns.updateModified() + } + default: + log.Warningf("Updates for %T not supported.", ob) + } +} + +// subsetsEquivalent checks if two endpoint subsets are significantly equivalent +// I.e. that they have the same ready addresses, host names, ports (including protocol +// and service names for SRV) +func subsetsEquivalent(sa, sb object.EndpointSubset) bool { + if len(sa.Addresses) != len(sb.Addresses) { + return false + } + if len(sa.Ports) != len(sb.Ports) { + return false + } + + // in Addresses and Ports, we should be able to rely on + // these being sorted and able to be compared + // they are supposed to be in a canonical format + for addr, aaddr := range sa.Addresses { + baddr := sb.Addresses[addr] + if aaddr.IP != baddr.IP { + return false + } + if aaddr.Hostname != baddr.Hostname { + return false + } + } + + for port, aport := range sa.Ports { + bport := sb.Ports[port] + if aport.Name != bport.Name { + return false + } + if aport.Port != bport.Port { + return false + } + if aport.Protocol != bport.Protocol { + return false + } + } + return true +} + +// endpointsEquivalent checks if the update to an endpoint is something +// that matters to us or if they are effectively equivalent. +func endpointsEquivalent(a, b *object.Endpoints) bool { + if a == nil || b == nil { + return false + } + + if len(a.Subsets) != len(b.Subsets) { + return false + } + + // we should be able to rely on + // these being sorted and able to be compared + // they are supposed to be in a canonical format + for i, sa := range a.Subsets { + sb := b.Subsets[i] + if !subsetsEquivalent(sa, sb) { + return false + } + } + return true +} + +// serviceModified checks the services passed for changes that result in changes +// to internal and or external records. 
It returns two booleans, one for internal
+// record changes, and a second for external record changes.
+func serviceModified(oldObj, newObj interface{}) (intSvc, extSvc bool) {
+	if oldObj != nil && newObj == nil {
+		// deleted service only modifies external zone records if it had external ips
+		return true, len(oldObj.(*object.Service).ExternalIPs) > 0
+	}
+
+	if oldObj == nil && newObj != nil {
+		// added service only modifies external zone records if it has external ips
+		return true, len(newObj.(*object.Service).ExternalIPs) > 0
+	}
+
+	newSvc := newObj.(*object.Service)
+	oldSvc := oldObj.(*object.Service)
+
+	// External IPs are mutable, affecting external zone records
+	if len(oldSvc.ExternalIPs) != len(newSvc.ExternalIPs) {
+		extSvc = true
+	} else {
+		for i := range oldSvc.ExternalIPs {
+			if oldSvc.ExternalIPs[i] != newSvc.ExternalIPs[i] {
+				extSvc = true
+				break
+			}
+		}
+	}
+
+	// ExternalName is mutable, affecting internal zone records
+	intSvc = oldSvc.ExternalName != newSvc.ExternalName
+
+	if intSvc && extSvc {
+		return intSvc, extSvc
+	}
+
+	// All Port fields are mutable, affecting both internal/external zone records
+	if len(oldSvc.Ports) != len(newSvc.Ports) {
+		return true, true
+	}
+	for i := range oldSvc.Ports {
+		if oldSvc.Ports[i].Name != newSvc.Ports[i].Name {
+			return true, true
+		}
+		if oldSvc.Ports[i].Port != newSvc.Ports[i].Port {
+			return true, true
+		}
+		if oldSvc.Ports[i].Protocol != newSvc.Ports[i].Protocol {
+			return true, true
+		}
+	}
+
+	return intSvc, extSvc
+}
+
+func (dns *dnsControl) Modified(external bool) int64 {
+	if external {
+		return atomic.LoadInt64(&dns.extModified)
+	}
+	return atomic.LoadInt64(&dns.modified)
+}
+
+// updateModified sets dns.modified to the current time.
+func (dns *dnsControl) updateModified() {
+	unix := time.Now().Unix()
+	atomic.StoreInt64(&dns.modified, unix)
+}
+
+// updateExtModified sets dns.extModified to the current time.
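+// These timestamps are surfaced through Modified() and are used as zone SOA
+// serial numbers (see ExternalSerial in external.go), so any relevant watch
+// event effectively bumps the serial. Illustrative only:
+//
+//	serial := uint32(k.APIConn.Modified(true)) // serial for the external zone
+//	_ = serial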
+func (dns *dnsControl) updateExtModified() { + unix := time.Now().Unix() + atomic.StoreInt64(&dns.extModified, unix) +} + +var errObj = errors.New("obj was not of the correct type") diff --git a/plugin/kubernetes/controller_test.go b/plugin/kubernetes/controller_test.go new file mode 100644 index 0000000..c36ab66 --- /dev/null +++ b/plugin/kubernetes/controller_test.go @@ -0,0 +1,303 @@ +package kubernetes + +import ( + "context" + "net" + "strconv" + "testing" + "time" + + "github.com/coredns/coredns/plugin/kubernetes/object" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" + api "k8s.io/api/core/v1" + discovery "k8s.io/api/discovery/v1" + meta "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/kubernetes/fake" +) + +func inc(ip net.IP) { + for j := len(ip) - 1; j >= 0; j-- { + ip[j]++ + if ip[j] > 0 { + break + } + } +} + +func kubernetesWithFakeClient(ctx context.Context, zone, cidr string, initEndpointsCache bool, svcType string) *Kubernetes { + client := fake.NewSimpleClientset() + dco := dnsControlOpts{ + zones: []string{zone}, + initEndpointsCache: initEndpointsCache, + } + controller := newdnsController(ctx, client, dco) + + // Add resources + _, err := client.CoreV1().Namespaces().Create(ctx, &api.Namespace{ObjectMeta: meta.ObjectMeta{Name: "testns"}}, meta.CreateOptions{}) + if err != nil { + log.Fatal(err) + } + generateSvcs(cidr, svcType, client) + generateEndpointSlices(cidr, client) + k := New([]string{"cluster.local."}) + k.APIConn = controller + return k +} + +func BenchmarkController(b *testing.B) { + ctx := context.Background() + k := kubernetesWithFakeClient(ctx, "cluster.local.", "10.0.0.0/24", true, "all") + + go k.APIConn.Run() + defer k.APIConn.Stop() + for !k.APIConn.HasSynced() { + time.Sleep(time.Millisecond) + } + + rw := &test.ResponseWriter{} + m := new(dns.Msg) + m.SetQuestion("svc1.testns.svc.cluster.local.", dns.TypeA) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + k.ServeDNS(ctx, rw, m) + } +} + +func TestEndpointsDisabled(t *testing.T) { + ctx := context.Background() + k := kubernetesWithFakeClient(ctx, "cluster.local.", "10.0.0.0/30", false, "headless") + k.opts.initEndpointsCache = false + + go k.APIConn.Run() + defer k.APIConn.Stop() + for !k.APIConn.HasSynced() { + time.Sleep(time.Millisecond) + } + + rw := &dnstest.Recorder{ResponseWriter: &test.ResponseWriter{}} + m := new(dns.Msg) + m.SetQuestion("svc2.testns.svc.cluster.local.", dns.TypeA) + k.ServeDNS(ctx, rw, m) + if rw.Msg.Rcode != dns.RcodeNameError { + t.Errorf("Expected NXDOMAIN, got %v", dns.RcodeToString[rw.Msg.Rcode]) + } +} + +func TestEndpointsEnabled(t *testing.T) { + ctx := context.Background() + k := kubernetesWithFakeClient(ctx, "cluster.local.", "10.0.0.0/30", true, "headless") + k.opts.initEndpointsCache = true + + go k.APIConn.Run() + defer k.APIConn.Stop() + for !k.APIConn.HasSynced() { + time.Sleep(time.Millisecond) + } + + rw := &dnstest.Recorder{ResponseWriter: &test.ResponseWriter{}} + m := new(dns.Msg) + m.SetQuestion("svc2.testns.svc.cluster.local.", dns.TypeA) + k.ServeDNS(ctx, rw, m) + if rw.Msg.Rcode != dns.RcodeSuccess { + t.Errorf("Expected SUCCESS, got %v", dns.RcodeToString[rw.Msg.Rcode]) + } +} + +func generateEndpointSlices(cidr string, client kubernetes.Interface) { + // https://groups.google.com/d/msg/golang-nuts/zlcYA4qk-94/TWRFHeXJCcYJ + ip, ipnet, err := net.ParseCIDR(cidr) + if err != nil { + log.Fatal(err) + } + + count := 1 + port := int32(80) 
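+	// discovery/v1 EndpointPort fields (Name, Port, Protocol) are pointers, so
+	// local variables are declared here to take their addresses below.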
+ protocol := api.Protocol("tcp") + name := "http" + eps := &discovery.EndpointSlice{ + Ports: []discovery.EndpointPort{ + { + Port: &port, + Protocol: &protocol, + Name: &name, + }, + }, + ObjectMeta: meta.ObjectMeta{ + Namespace: "testns", + }, + } + ctx := context.TODO() + for ip := ip.Mask(ipnet.Mask); ipnet.Contains(ip); inc(ip) { + hostname := "foo" + strconv.Itoa(count) + eps.Endpoints = []discovery.Endpoint{ + { + Addresses: []string{ip.String()}, + Hostname: &hostname, + }, + } + eps.ObjectMeta.Name = "svc" + strconv.Itoa(count) + eps.ObjectMeta.Labels = map[string]string{discovery.LabelServiceName: eps.ObjectMeta.Name} + _, err := client.DiscoveryV1().EndpointSlices("testns").Create(ctx, eps, meta.CreateOptions{}) + if err != nil { + log.Fatal(err) + } + count++ + } +} + +func generateSvcs(cidr string, svcType string, client kubernetes.Interface) { + ip, ipnet, err := net.ParseCIDR(cidr) + if err != nil { + log.Fatal(err) + } + + count := 1 + switch svcType { + case "clusterip": + for ip := ip.Mask(ipnet.Mask); ipnet.Contains(ip); inc(ip) { + createClusterIPSvc(count, client, ip) + count++ + } + case "headless": + for ip := ip.Mask(ipnet.Mask); ipnet.Contains(ip); inc(ip) { + createHeadlessSvc(count, client, ip) + count++ + } + case "external": + for ip := ip.Mask(ipnet.Mask); ipnet.Contains(ip); inc(ip) { + createExternalSvc(count, client, ip) + count++ + } + default: + for ip := ip.Mask(ipnet.Mask); ipnet.Contains(ip); inc(ip) { + if count%3 == 0 { + createClusterIPSvc(count, client, ip) + } else if count%3 == 1 { + createHeadlessSvc(count, client, ip) + } else if count%3 == 2 { + createExternalSvc(count, client, ip) + } + count++ + } + } +} + +func createClusterIPSvc(suffix int, client kubernetes.Interface, ip net.IP) { + ctx := context.TODO() + client.CoreV1().Services("testns").Create(ctx, &api.Service{ + ObjectMeta: meta.ObjectMeta{ + Name: "svc" + strconv.Itoa(suffix), + Namespace: "testns", + }, + Spec: api.ServiceSpec{ + ClusterIP: ip.String(), + Ports: []api.ServicePort{{ + Name: "http", + Protocol: "tcp", + Port: 80, + }}, + }, + }, meta.CreateOptions{}) +} + +func createHeadlessSvc(suffix int, client kubernetes.Interface, ip net.IP) { + ctx := context.TODO() + client.CoreV1().Services("testns").Create(ctx, &api.Service{ + ObjectMeta: meta.ObjectMeta{ + Name: "svc" + strconv.Itoa(suffix), + Namespace: "testns", + }, + Spec: api.ServiceSpec{ + ClusterIP: api.ClusterIPNone, + }, + }, meta.CreateOptions{}) +} + +func createExternalSvc(suffix int, client kubernetes.Interface, ip net.IP) { + ctx := context.TODO() + client.CoreV1().Services("testns").Create(ctx, &api.Service{ + ObjectMeta: meta.ObjectMeta{ + Name: "svc" + strconv.Itoa(suffix), + Namespace: "testns", + }, + Spec: api.ServiceSpec{ + ExternalName: "coredns" + strconv.Itoa(suffix) + ".io", + Ports: []api.ServicePort{{ + Name: "http", + Protocol: "tcp", + Port: 80, + }}, + Type: api.ServiceTypeExternalName, + }, + }, meta.CreateOptions{}) +} + +func TestServiceModified(t *testing.T) { + var tests = []struct { + oldSvc interface{} + newSvc interface{} + ichanged bool + echanged bool + }{ + { + oldSvc: nil, + newSvc: &object.Service{}, + ichanged: true, + echanged: false, + }, + { + oldSvc: &object.Service{}, + newSvc: nil, + ichanged: true, + echanged: false, + }, + { + oldSvc: nil, + newSvc: &object.Service{ExternalIPs: []string{"10.0.0.1"}}, + ichanged: true, + echanged: true, + }, + { + oldSvc: &object.Service{ExternalIPs: []string{"10.0.0.1"}}, + newSvc: nil, + ichanged: true, + echanged: true, + }, + { + 
oldSvc: &object.Service{ExternalIPs: []string{"10.0.0.1"}},
+			newSvc: &object.Service{ExternalIPs: []string{"10.0.0.2"}},
+			ichanged: false,
+			echanged: true,
+		},
+		{
+			oldSvc: &object.Service{ExternalName: "10.0.0.1"},
+			newSvc: &object.Service{ExternalName: "10.0.0.2"},
+			ichanged: true,
+			echanged: false,
+		},
+		{
+			oldSvc: &object.Service{Ports: []api.ServicePort{{Name: "test1"}}},
+			newSvc: &object.Service{Ports: []api.ServicePort{{Name: "test2"}}},
+			ichanged: true,
+			echanged: true,
+		},
+		{
+			oldSvc: &object.Service{Ports: []api.ServicePort{{Name: "test1"}}},
+			newSvc: &object.Service{Ports: []api.ServicePort{{Name: "test2"}, {Name: "test3"}}},
+			ichanged: true,
+			echanged: true,
+		},
+	}
+
+	for i, test := range tests {
+		ichanged, echanged := serviceModified(test.oldSvc, test.newSvc)
+		if test.ichanged != ichanged || test.echanged != echanged {
+			t.Errorf("Expected %v, %v for test %v. Got %v, %v", test.ichanged, test.echanged, i, ichanged, echanged)
+		}
+	}
+}
diff --git a/plugin/kubernetes/external.go b/plugin/kubernetes/external.go
new file mode 100644
index 0000000..f6705d2
--- /dev/null
+++ b/plugin/kubernetes/external.go
@@ -0,0 +1,237 @@
+package kubernetes
+
+import (
+	"strings"
+
+	"github.com/coredns/coredns/plugin/etcd/msg"
+	"github.com/coredns/coredns/plugin/kubernetes/object"
+	"github.com/coredns/coredns/plugin/pkg/dnsutil"
+	"github.com/coredns/coredns/request"
+
+	"github.com/miekg/dns"
+)
+
+// These constants are used to distinguish between records in ExternalServices headless
+// return values. They are always appended to a key in a map, which is either the base
+// service key, e.g. /com/example/namespace/service/endpoint, or
+// /com/example/namespace/service/_http/_tcp/port.protocol. This allows us to distinguish
+// services in the implementation of the Transfer protocol; see plugin/k8s_external/transfer.go.
+const (
+	Endpoint     = "endpoint"
+	PortProtocol = "port.protocol"
+)
+
+// External implements the ExternalFunc call from the external plugin.
+// It returns any services matching in the services' ExternalIPs and, if enabled, headless endpoints.
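+// For context, External is normally driven by the k8s_external plugin; a
+// minimal Corefile sketch (the zone example.org is an assumption):
+//
+//	. {
+//	    kubernetes cluster.local
+//	    k8s_external example.org
+//	}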
+func (k *Kubernetes) External(state request.Request, headless bool) ([]msg.Service, int) { + if state.QType() == dns.TypePTR { + ip := dnsutil.ExtractAddressFromReverse(state.Name()) + if ip != "" { + svcs, err := k.ExternalReverse(ip) + if err != nil { + return nil, dns.RcodeNameError + } + return svcs, dns.RcodeSuccess + } + // for invalid reverse names, fall through to determine proper nxdomain/nodata response + } + + base, _ := dnsutil.TrimZone(state.Name(), state.Zone) + + segs := dns.SplitDomainName(base) + last := len(segs) - 1 + if last < 0 { + return nil, dns.RcodeServerFailure + } + // We are dealing with a fairly normal domain name here, but we still need to have the service, + // namespace and if present, endpoint: + // service.namespace.<base> or + // endpoint.service.namespace.<base> + var port, protocol, endpoint string + namespace := segs[last] + if !k.namespaceExposed(namespace) { + return nil, dns.RcodeNameError + } + + last-- + if last < 0 { + return nil, dns.RcodeSuccess + } + + service := segs[last] + last-- + if last == 0 { + endpoint = stripUnderscore(segs[last]) + last-- + } else if last == 1 { + protocol = stripUnderscore(segs[last]) + port = stripUnderscore(segs[last-1]) + last -= 2 + } + + if last != -1 { + // too long + return nil, dns.RcodeNameError + } + + var ( + endpointsList []*object.Endpoints + serviceList []*object.Service + ) + + idx := object.ServiceKey(service, namespace) + serviceList = k.APIConn.SvcIndex(idx) + + services := []msg.Service{} + zonePath := msg.Path(state.Zone, coredns) + rcode := dns.RcodeNameError + + for _, svc := range serviceList { + if namespace != svc.Namespace { + continue + } + if service != svc.Name { + continue + } + + if headless && len(svc.ExternalIPs) == 0 && (svc.Headless() || endpoint != "") { + if endpointsList == nil { + endpointsList = k.APIConn.EpIndex(idx) + } + // Endpoint query or headless service + for _, ep := range endpointsList { + if object.EndpointsKey(svc.Name, svc.Namespace) != ep.Index { + continue + } + + for _, eps := range ep.Subsets { + for _, addr := range eps.Addresses { + if endpoint != "" && !match(endpoint, endpointHostname(addr, k.endpointNameMode)) { + continue + } + + for _, p := range eps.Ports { + if !(matchPortAndProtocol(port, p.Name, protocol, p.Protocol)) { + continue + } + rcode = dns.RcodeSuccess + s := msg.Service{Host: addr.IP, Port: int(p.Port), TTL: k.ttl} + s.Key = strings.Join([]string{zonePath, svc.Namespace, svc.Name, endpointHostname(addr, k.endpointNameMode)}, "/") + + services = append(services, s) + } + } + } + } + continue + } else { + for _, ip := range svc.ExternalIPs { + for _, p := range svc.Ports { + if !(matchPortAndProtocol(port, p.Name, protocol, string(p.Protocol))) { + continue + } + rcode = dns.RcodeSuccess + s := msg.Service{Host: ip, Port: int(p.Port), TTL: k.ttl} + s.Key = strings.Join([]string{zonePath, svc.Namespace, svc.Name}, "/") + + services = append(services, s) + } + } + } + } + if state.QType() == dns.TypePTR { + // if this was a PTR request, return empty service list, but retain rcode for proper nxdomain/nodata response + return nil, rcode + } + return services, rcode +} + +// ExternalAddress returns the external service address(es) for the CoreDNS service. +func (k *Kubernetes) ExternalAddress(state request.Request, headless bool) []dns.RR { + // If CoreDNS is running inside the Kubernetes cluster: k.nsAddrs() will return the external IPs of the services + // targeting the CoreDNS Pod. 
+ // If CoreDNS is running outside of the Kubernetes cluster: k.nsAddrs() will return the first non-loopback IP + // address seen on the local system it is running on. This could be the wrong answer if coredns is using the *bind* + // plugin to bind to a different IP address. + return k.nsAddrs(true, headless, state.Zone) +} + +// ExternalServices returns all services with external IPs and if enabled headless services +func (k *Kubernetes) ExternalServices(zone string, headless bool) (services []msg.Service, headlessServices map[string][]msg.Service) { + zonePath := msg.Path(zone, coredns) + headlessServices = make(map[string][]msg.Service) + for _, svc := range k.APIConn.ServiceList() { + // Endpoints and headless services + if headless && len(svc.ExternalIPs) == 0 && svc.Headless() { + idx := object.ServiceKey(svc.Name, svc.Namespace) + endpointsList := k.APIConn.EpIndex(idx) + + for _, ep := range endpointsList { + for _, eps := range ep.Subsets { + for _, addr := range eps.Addresses { + // we need to have some answers grouped together + // 1. for endpoint requests eg. endpoint-0.service.example.com - will always have one endpoint + // 2. for service requests eg. service.example.com - can have multiple endpoints + // 3. for port.protocol requests eg. _http._tcp.service.example.com - can have multiple endpoints + for _, p := range eps.Ports { + s := msg.Service{Host: addr.IP, Port: int(p.Port), TTL: k.ttl} + baseSvc := strings.Join([]string{zonePath, svc.Namespace, svc.Name}, "/") + s.Key = strings.Join([]string{baseSvc, endpointHostname(addr, k.endpointNameMode)}, "/") + headlessServices[strings.Join([]string{baseSvc, Endpoint}, "/")] = append(headlessServices[strings.Join([]string{baseSvc, Endpoint}, "/")], s) + + // As per spec unnamed ports do not have a srv record + // https://github.com/kubernetes/dns/blob/master/docs/specification.md#232---srv-records + if p.Name == "" { + continue + } + s.Host = msg.Domain(s.Key) + s.Key = strings.Join(append([]string{zonePath, svc.Namespace, svc.Name}, strings.ToLower("_"+p.Protocol), strings.ToLower("_"+p.Name)), "/") + headlessServices[strings.Join([]string{s.Key, PortProtocol}, "/")] = append(headlessServices[strings.Join([]string{s.Key, PortProtocol}, "/")], s) + } + } + } + } + continue + } else { + for _, ip := range svc.ExternalIPs { + for _, p := range svc.Ports { + s := msg.Service{Host: ip, Port: int(p.Port), TTL: k.ttl} + s.Key = strings.Join([]string{zonePath, svc.Namespace, svc.Name}, "/") + services = append(services, s) + s.Key = strings.Join(append([]string{zonePath, svc.Namespace, svc.Name}, strings.ToLower("_"+string(p.Protocol)), strings.ToLower("_"+p.Name)), "/") + s.TargetStrip = 2 + services = append(services, s) + } + } + } + } + return services, headlessServices +} + +// ExternalSerial returns the serial of the external zone +func (k *Kubernetes) ExternalSerial(string) uint32 { + return uint32(k.APIConn.Modified(true)) +} + +// ExternalReverse does a reverse lookup for the external IPs +func (k *Kubernetes) ExternalReverse(ip string) ([]msg.Service, error) { + records := k.serviceRecordForExternalIP(ip) + if len(records) == 0 { + return records, errNoItems + } + return records, nil +} + +func (k *Kubernetes) serviceRecordForExternalIP(ip string) []msg.Service { + var svcs []msg.Service + for _, service := range k.APIConn.SvcExtIndexReverse(ip) { + if len(k.Namespaces) > 0 && !k.namespaceExposed(service.Namespace) { + continue + } + domain := strings.Join([]string{service.Name, service.Namespace}, ".") + svcs = 
append(svcs, msg.Service{Host: domain, TTL: k.ttl}) + } + return svcs +} diff --git a/plugin/kubernetes/external_test.go b/plugin/kubernetes/external_test.go new file mode 100644 index 0000000..474b7be --- /dev/null +++ b/plugin/kubernetes/external_test.go @@ -0,0 +1,198 @@ +package kubernetes + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin/etcd/msg" + "github.com/coredns/coredns/plugin/kubernetes/object" + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" + api "k8s.io/api/core/v1" +) + +var extCases = []struct { + Qname string + Qtype uint16 + Msg []msg.Service + Rcode int +}{ + { + Qname: "svc1.testns.example.org.", Rcode: dns.RcodeSuccess, + Msg: []msg.Service{ + {Host: "1.2.3.4", Port: 80, TTL: 5, Key: "/c/org/example/testns/svc1"}, + }, + }, + { + Qname: "svc6.testns.example.org.", Rcode: dns.RcodeSuccess, + Msg: []msg.Service{ + {Host: "1:2::5", Port: 80, TTL: 5, Key: "/c/org/example/testns/svc6"}, + }, + }, + { + Qname: "_http._tcp.svc1.testns.example.com.", Rcode: dns.RcodeSuccess, + Msg: []msg.Service{ + {Host: "1.2.3.4", Port: 80, TTL: 5, Key: "/c/org/example/testns/svc1"}, + }, + }, + { + Qname: "svc0.testns.example.com.", Rcode: dns.RcodeNameError, + }, + { + Qname: "svc0.svc-nons.example.com.", Rcode: dns.RcodeNameError, + }, + { + Qname: "svc-headless.testns.example.com.", Rcode: dns.RcodeSuccess, + Msg: []msg.Service{ + {Host: "1.2.3.4", Port: 80, TTL: 5, Weight: 50, Key: "/c/org/example/testns/svc-headless/endpoint-svc-0"}, + {Host: "1.2.3.5", Port: 80, TTL: 5, Weight: 50, Key: "/c/org/example/testns/svc-headless/endpoint-svc-1"}, + }, + }, + { + Qname: "endpoint-svc-0.svc-headless.testns.example.com.", Rcode: dns.RcodeSuccess, + Msg: []msg.Service{ + {Host: "1.2.3.4", Port: 80, TTL: 5, Weight: 100, Key: "/c/org/example/testns/svc-headless/endpoint-svc-0"}, + }, + }, + { + Qname: "endpoint-1.svc-nons.testns.example.com.", Rcode: dns.RcodeNameError, + }, +} + +func TestExternal(t *testing.T) { + k := New([]string{"cluster.local."}) + k.APIConn = &external{} + k.Next = test.NextHandler(dns.RcodeSuccess, nil) + k.Namespaces = map[string]struct{}{"testns": {}} + + for i, tc := range extCases { + state := testRequest(tc.Qname) + + svc, rcode := k.External(state, true) + + if x := tc.Rcode; x != rcode { + t.Errorf("Test %d, expected rcode %d, got %d", i, x, rcode) + } + + if len(svc) != len(tc.Msg) { + t.Errorf("Test %d, expected %d for messages, got %d", i, len(tc.Msg), len(svc)) + } + + for j, s := range svc { + if x := tc.Msg[j].Key; x != s.Key { + t.Errorf("Test %d, expected key %s, got %s", i, x, s.Key) + } + } + } +} + +type external struct{} + +func (external) HasSynced() bool { return true } +func (external) Run() {} +func (external) Stop() error { return nil } +func (external) EpIndexReverse(string) []*object.Endpoints { return nil } +func (external) SvcIndexReverse(string) []*object.Service { return nil } +func (external) SvcExtIndexReverse(string) []*object.Service { return nil } +func (external) Modified(bool) int64 { return 0 } +func (external) EpIndex(s string) []*object.Endpoints { + return epIndexExternal[s] +} +func (external) EndpointsList() []*object.Endpoints { + var eps []*object.Endpoints + for _, ep := range epIndexExternal { + eps = append(eps, ep...) 
+ } + return eps +} +func (external) GetNodeByName(ctx context.Context, name string) (*api.Node, error) { return nil, nil } +func (external) SvcIndex(s string) []*object.Service { return svcIndexExternal[s] } +func (external) PodIndex(string) []*object.Pod { return nil } + +func (external) GetNamespaceByName(name string) (*object.Namespace, error) { + return &object.Namespace{ + Name: name, + }, nil +} + +var epIndexExternal = map[string][]*object.Endpoints{ + "svc-headless.testns": { + { + Name: "svc-headless", + Namespace: "testns", + Index: "svc-headless.testns", + Subsets: []object.EndpointSubset{ + { + Ports: []object.EndpointPort{ + { + Port: 80, + Name: "http", + Protocol: "TCP", + }, + }, + Addresses: []object.EndpointAddress{ + { + IP: "1.2.3.4", + Hostname: "endpoint-svc-0", + NodeName: "test-node", + TargetRefName: "endpoint-svc-0", + }, + { + IP: "1.2.3.5", + Hostname: "endpoint-svc-1", + NodeName: "test-node", + TargetRefName: "endpoint-svc-1", + }, + }, + }, + }, + }, + }, +} + +var svcIndexExternal = map[string][]*object.Service{ + "svc1.testns": { + { + Name: "svc1", + Namespace: "testns", + Type: api.ServiceTypeClusterIP, + ClusterIPs: []string{"10.0.0.1"}, + ExternalIPs: []string{"1.2.3.4"}, + Ports: []api.ServicePort{{Name: "http", Protocol: "tcp", Port: 80}}, + }, + }, + "svc6.testns": { + { + Name: "svc6", + Namespace: "testns", + Type: api.ServiceTypeClusterIP, + ClusterIPs: []string{"10.0.0.3"}, + ExternalIPs: []string{"1:2::5"}, + Ports: []api.ServicePort{{Name: "http", Protocol: "tcp", Port: 80}}, + }, + }, + "svc-headless.testns": { + { + Name: "svc-headless", + Namespace: "testns", + Type: api.ServiceTypeClusterIP, + ClusterIPs: []string{api.ClusterIPNone}, + Ports: []api.ServicePort{{Name: "http", Protocol: "tcp", Port: 80}}, + }, + }, +} + +func (external) ServiceList() []*object.Service { + var svcs []*object.Service + for _, svc := range svcIndexExternal { + svcs = append(svcs, svc...) + } + return svcs +} + +func testRequest(name string) request.Request { + m := new(dns.Msg).SetQuestion(name, dns.TypeA) + return request.Request{W: &test.ResponseWriter{}, Req: m, Zone: "example.org."} +} diff --git a/plugin/kubernetes/handler.go b/plugin/kubernetes/handler.go new file mode 100644 index 0000000..d673a7a --- /dev/null +++ b/plugin/kubernetes/handler.go @@ -0,0 +1,94 @@ +package kubernetes + +import ( + "context" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// ServeDNS implements the plugin.Handler interface. 
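+// Illustrative invocation, mirroring the tests in this package (the recorder
+// and query name are assumptions):
+//
+//	m := new(dns.Msg)
+//	m.SetQuestion("svc1.testns.svc.cluster.local.", dns.TypeA)
+//	w := dnstest.NewRecorder(&test.ResponseWriter{})
+//	rcode, err := k.ServeDNS(context.TODO(), w, m)
+//	_, _ = rcode, err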
+func (k Kubernetes) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + state := request.Request{W: w, Req: r} + + qname := state.QName() + zone := plugin.Zones(k.Zones).Matches(qname) + if zone == "" { + return plugin.NextOrFailure(k.Name(), k.Next, ctx, w, r) + } + zone = qname[len(qname)-len(zone):] // maintain case of original query + state.Zone = zone + + var ( + records []dns.RR + extra []dns.RR + truncated bool + err error + ) + + switch state.QType() { + case dns.TypeA: + records, truncated, err = plugin.A(ctx, &k, zone, state, nil, plugin.Options{}) + case dns.TypeAAAA: + records, truncated, err = plugin.AAAA(ctx, &k, zone, state, nil, plugin.Options{}) + case dns.TypeTXT: + records, truncated, err = plugin.TXT(ctx, &k, zone, state, nil, plugin.Options{}) + case dns.TypeCNAME: + records, err = plugin.CNAME(ctx, &k, zone, state, plugin.Options{}) + case dns.TypePTR: + records, err = plugin.PTR(ctx, &k, zone, state, plugin.Options{}) + case dns.TypeMX: + records, extra, err = plugin.MX(ctx, &k, zone, state, plugin.Options{}) + case dns.TypeSRV: + records, extra, err = plugin.SRV(ctx, &k, zone, state, plugin.Options{}) + case dns.TypeSOA: + if qname == zone { + records, err = plugin.SOA(ctx, &k, zone, state, plugin.Options{}) + } + case dns.TypeAXFR, dns.TypeIXFR: + return dns.RcodeRefused, nil + case dns.TypeNS: + if state.Name() == zone { + records, extra, err = plugin.NS(ctx, &k, zone, state, plugin.Options{}) + break + } + fallthrough + default: + // Do a fake A lookup, so we can distinguish between NODATA and NXDOMAIN + fake := state.NewWithQuestion(state.QName(), dns.TypeA) + fake.Zone = state.Zone + _, _, err = plugin.A(ctx, &k, zone, fake, nil, plugin.Options{}) + } + + if k.IsNameError(err) { + if k.Fall.Through(state.Name()) { + return plugin.NextOrFailure(k.Name(), k.Next, ctx, w, r) + } + if !k.APIConn.HasSynced() { + // If we haven't synchronized with the kubernetes cluster, return server failure + return plugin.BackendError(ctx, &k, zone, dns.RcodeServerFailure, state, nil /* err */, plugin.Options{}) + } + return plugin.BackendError(ctx, &k, zone, dns.RcodeNameError, state, nil /* err */, plugin.Options{}) + } + if err != nil { + return dns.RcodeServerFailure, err + } + + if len(records) == 0 { + return plugin.BackendError(ctx, &k, zone, dns.RcodeSuccess, state, nil, plugin.Options{}) + } + + m := new(dns.Msg) + m.SetReply(r) + m.Truncated = truncated + m.Authoritative = true + m.Answer = append(m.Answer, records...) + m.Extra = append(m.Extra, extra...) + w.WriteMsg(m) + return dns.RcodeSuccess, nil +} + +// Name implements the Handler interface. +func (k Kubernetes) Name() string { return "kubernetes" } diff --git a/plugin/kubernetes/handler_case_test.go b/plugin/kubernetes/handler_case_test.go new file mode 100644 index 0000000..c3f90f1 --- /dev/null +++ b/plugin/kubernetes/handler_case_test.go @@ -0,0 +1,80 @@ +package kubernetes + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +var dnsPreserveCaseCases = []test.Case{ + // Negative response + { + Qname: "not-a-service.testns.svc.ClUsTeR.lOcAl.", Qtype: dns.TypeA, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("ClUsTeR.lOcAl. 5 IN SOA ns.dns.ClUsTeR.lOcAl. hostmaster.ClUsTeR.lOcAl. 
1499347823 7200 1800 86400 5"), + }, + }, + // A Service + { + Qname: "SvC1.TeStNs.SvC.cLuStEr.LoCaL.", Qtype: dns.TypeA, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.A("SvC1.TeStNs.SvC.cLuStEr.LoCaL. 5 IN A 10.0.0.1"), + }, + }, + // SRV Service + { + Qname: "_HtTp._TcP.sVc1.TeStNs.SvC.cLuStEr.LoCaL.", Qtype: dns.TypeSRV, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.SRV("_HtTp._TcP.sVc1.TeStNs.SvC.cLuStEr.LoCaL. 5 IN SRV 0 100 80 svc1.testns.svc.cLuStEr.LoCaL."), + }, + Extra: []dns.RR{ + test.A("svc1.testns.svc.cLuStEr.LoCaL. 5 IN A 10.0.0.1"), + }, + }, + { + Qname: "Cluster.local.", Qtype: dns.TypeSOA, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.SOA("Cluster.local. 5 IN SOA ns.dns.Cluster.local. hostmaster.Cluster.local. 1499347823 7200 1800 86400 5"), + }, + }, +} + +func TestPreserveCase(t *testing.T) { + k := New([]string{"cluster.local."}) + k.APIConn = &APIConnServeTest{} + k.opts.ignoreEmptyService = true + k.Next = test.NextHandler(dns.RcodeSuccess, nil) + ctx := context.TODO() + + for i, tc := range dnsPreserveCaseCases { + r := tc.Msg() + + w := dnstest.NewRecorder(&test.ResponseWriter{}) + + _, err := k.ServeDNS(ctx, w, r) + if err != tc.Error { + t.Errorf("Test %d expected no error, got %v", i, err) + return + } + if tc.Error != nil { + continue + } + + resp := w.Msg + if resp == nil { + t.Fatalf("Test %d, got nil message and no error for %q", i, r.Question[0].Name) + } + + if err := test.SortAndCheck(resp, tc); err != nil { + t.Error(err) + } + } +} diff --git a/plugin/kubernetes/handler_ignore_emptyservice_test.go b/plugin/kubernetes/handler_ignore_emptyservice_test.go new file mode 100644 index 0000000..7af77fe --- /dev/null +++ b/plugin/kubernetes/handler_ignore_emptyservice_test.go @@ -0,0 +1,67 @@ +package kubernetes + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +var dnsEmptyServiceTestCases = []test.Case{ + // A Service + { + Qname: "svcempty.testns.svc.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1499347823 7200 1800 86400 5"), + }, + }, + // CNAME to external + { + Qname: "external.testns.svc.cluster.local.", Qtype: dns.TypeCNAME, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.CNAME("external.testns.svc.cluster.local. 
5 IN CNAME ext.interwebs.test."), + }, + }, +} + +func TestServeDNSEmptyService(t *testing.T) { + k := New([]string{"cluster.local."}) + k.APIConn = &APIConnServeTest{} + k.opts.ignoreEmptyService = true + k.Next = test.NextHandler(dns.RcodeSuccess, nil) + ctx := context.TODO() + + for i, tc := range dnsEmptyServiceTestCases { + r := tc.Msg() + + w := dnstest.NewRecorder(&test.ResponseWriter{}) + + _, err := k.ServeDNS(ctx, w, r) + if err != tc.Error { + t.Errorf("Test %d expected no error, got %v", i, err) + return + } + if tc.Error != nil { + continue + } + + resp := w.Msg + if resp == nil { + t.Fatalf("Test %d, got nil message and no error for %q", i, r.Question[0].Name) + } + + // Before sorting, make sure that CNAMES do not appear after their target records + if err := test.CNAMEOrder(resp); err != nil { + t.Error(err) + } + + if err := test.SortAndCheck(resp, tc); err != nil { + t.Error(err) + } + } +} diff --git a/plugin/kubernetes/handler_pod_disabled_test.go b/plugin/kubernetes/handler_pod_disabled_test.go new file mode 100644 index 0000000..be7e7a3 --- /dev/null +++ b/plugin/kubernetes/handler_pod_disabled_test.go @@ -0,0 +1,60 @@ +package kubernetes + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +var podModeDisabledCases = []test.Case{ + { + Qname: "10-240-0-1.podns.pod.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1499347823 7200 1800 86400 5"), + }, + }, + { + Qname: "172-0-0-2.podns.pod.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1499347823 7200 1800 86400 5"), + }, + }, +} + +func TestServeDNSModeDisabled(t *testing.T) { + k := New([]string{"cluster.local."}) + k.APIConn = &APIConnServeTest{} + k.Next = test.NextHandler(dns.RcodeSuccess, nil) + k.podMode = podModeDisabled + ctx := context.TODO() + + for i, tc := range podModeDisabledCases { + r := tc.Msg() + + w := dnstest.NewRecorder(&test.ResponseWriter{}) + + _, err := k.ServeDNS(ctx, w, r) + if err != tc.Error { + t.Errorf("Test %d got unexpected error %v", i, err) + return + } + if tc.Error != nil { + continue + } + + resp := w.Msg + if resp == nil { + t.Fatalf("Test %d, got nil message and no error for %q", i, r.Question[0].Name) + } + + if err := test.SortAndCheck(resp, tc); err != nil { + t.Error(err) + } + } +} diff --git a/plugin/kubernetes/handler_pod_insecure_test.go b/plugin/kubernetes/handler_pod_insecure_test.go new file mode 100644 index 0000000..b01d53f --- /dev/null +++ b/plugin/kubernetes/handler_pod_insecure_test.go @@ -0,0 +1,95 @@ +package kubernetes + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +var podModeInsecureCases = []test.Case{ + { + Qname: "10-240-0-1.podns.pod.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.A("10-240-0-1.podns.pod.cluster.local. 5 IN A 10.240.0.1"), + }, + }, + { + Qname: "172-0-0-2.podns.pod.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.A("172-0-0-2.podns.pod.cluster.local. 
5 IN A 172.0.0.2"), + }, + }, + { + Qname: "blah.podns.pod.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1515173576 7200 1800 86400 30"), + }, + }, + { + Qname: "blah.podns.pod.cluster.local.", Qtype: dns.TypeAAAA, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1515173576 7200 1800 86400 30"), + }, + }, + { + Qname: "blah.podns.pod.cluster.local.", Qtype: dns.TypeHINFO, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1515173576 7200 1800 86400 30"), + }, + }, + { + Qname: "blah.pod-nons.pod.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1515173576 7200 1800 86400 30"), + }, + }, + { + Qname: "podns.pod.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeSuccess, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1515173576 7200 1800 86400 30"), + }, + }, +} + +func TestServeDNSModeInsecure(t *testing.T) { + k := New([]string{"cluster.local."}) + k.APIConn = &APIConnServeTest{} + k.Next = test.NextHandler(dns.RcodeSuccess, nil) + ctx := context.TODO() + k.podMode = podModeInsecure + + for i, tc := range podModeInsecureCases { + r := tc.Msg() + + w := dnstest.NewRecorder(&test.ResponseWriter{}) + + _, err := k.ServeDNS(ctx, w, r) + if err != tc.Error { + t.Errorf("Test %d expected no error, got %v", i, err) + return + } + if tc.Error != nil { + continue + } + + resp := w.Msg + if resp == nil { + t.Fatalf("Test %d, got nil message and no error for %q", i, r.Question[0].Name) + } + + if err := test.SortAndCheck(resp, tc); err != nil { + t.Error(err) + } + } +} diff --git a/plugin/kubernetes/handler_pod_verified_test.go b/plugin/kubernetes/handler_pod_verified_test.go new file mode 100644 index 0000000..c8b09c4 --- /dev/null +++ b/plugin/kubernetes/handler_pod_verified_test.go @@ -0,0 +1,81 @@ +package kubernetes + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +var podModeVerifiedCases = []test.Case{ + { + Qname: "10-240-0-1.podns.pod.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.A("10-240-0-1.podns.pod.cluster.local. 5 IN A 10.240.0.1"), + }, + }, + { + Qname: "podns.pod.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeSuccess, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1499347823 7200 1800 86400 5"), + }, + }, + { + Qname: "svcns.svc.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeSuccess, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1499347823 7200 1800 86400 5"), + }, + }, + { + Qname: "pod-nons.pod.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1499347823 7200 1800 86400 5"), + }, + }, + { + Qname: "172-0-0-2.podns.pod.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 
1499347823 7200 1800 86400 5"), + }, + }, +} + +func TestServeDNSModeVerified(t *testing.T) { + k := New([]string{"cluster.local."}) + k.APIConn = &APIConnServeTest{} + k.Next = test.NextHandler(dns.RcodeSuccess, nil) + ctx := context.TODO() + k.podMode = podModeVerified + + for i, tc := range podModeVerifiedCases { + r := tc.Msg() + + w := dnstest.NewRecorder(&test.ResponseWriter{}) + + _, err := k.ServeDNS(ctx, w, r) + if err != tc.Error { + t.Errorf("Test %d expected no error, got %v", i, err) + return + } + if tc.Error != nil { + continue + } + + resp := w.Msg + if resp == nil { + t.Fatalf("Test %d, got nil message and no error for %q", i, r.Question[0].Name) + } + + if err := test.SortAndCheck(resp, tc); err != nil { + t.Error(err) + } + } +} diff --git a/plugin/kubernetes/handler_test.go b/plugin/kubernetes/handler_test.go new file mode 100644 index 0000000..405dc73 --- /dev/null +++ b/plugin/kubernetes/handler_test.go @@ -0,0 +1,848 @@ +package kubernetes + +import ( + "context" + "fmt" + "testing" + + "github.com/coredns/coredns/plugin/kubernetes/object" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" + api "k8s.io/api/core/v1" + meta "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +type kubeTestCase struct { + Upstream Upstreamer + Truncated bool + test.Case +} + +var dnsTestCases = []kubeTestCase{ + // A Service + {Case: test.Case{ + Qname: "svc1.testns.svc.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.A("svc1.testns.svc.cluster.local. 5 IN A 10.0.0.1"), + }, + }}, + {Case: test.Case{ + Qname: "svcempty.testns.svc.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.A("svcempty.testns.svc.cluster.local. 5 IN A 10.0.0.1"), + }, + }}, + {Case: test.Case{ + Qname: "svc1.testns.svc.cluster.local.", Qtype: dns.TypeSRV, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{test.SRV("svc1.testns.svc.cluster.local. 5 IN SRV 0 100 80 svc1.testns.svc.cluster.local.")}, + Extra: []dns.RR{test.A("svc1.testns.svc.cluster.local. 5 IN A 10.0.0.1")}, + }}, + {Case: test.Case{ + Qname: "svcempty.testns.svc.cluster.local.", Qtype: dns.TypeSRV, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{test.SRV("svcempty.testns.svc.cluster.local. 5 IN SRV 0 100 80 svcempty.testns.svc.cluster.local.")}, + Extra: []dns.RR{test.A("svcempty.testns.svc.cluster.local. 5 IN A 10.0.0.1")}, + }}, + {Case: test.Case{ + Qname: "svc6.testns.svc.cluster.local.", Qtype: dns.TypeSRV, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{test.SRV("svc6.testns.svc.cluster.local. 5 IN SRV 0 100 80 svc6.testns.svc.cluster.local.")}, + Extra: []dns.RR{test.AAAA("svc6.testns.svc.cluster.local. 5 IN AAAA 1234:abcd::1")}, + }}, + // SRV Service + {Case: test.Case{ + + Qname: "_http._tcp.svc1.testns.svc.cluster.local.", Qtype: dns.TypeSRV, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.SRV("_http._tcp.svc1.testns.svc.cluster.local. 5 IN SRV 0 100 80 svc1.testns.svc.cluster.local."), + }, + Extra: []dns.RR{ + test.A("svc1.testns.svc.cluster.local. 5 IN A 10.0.0.1"), + }, + }}, + {Case: test.Case{ + + Qname: "_http._tcp.svcempty.testns.svc.cluster.local.", Qtype: dns.TypeSRV, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.SRV("_http._tcp.svcempty.testns.svc.cluster.local. 5 IN SRV 0 100 80 svcempty.testns.svc.cluster.local."), + }, + Extra: []dns.RR{ + test.A("svcempty.testns.svc.cluster.local. 
5 IN A 10.0.0.1"), + }, + }}, + // A Service (Headless) + {Case: test.Case{ + Qname: "hdls1.testns.svc.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.A("hdls1.testns.svc.cluster.local. 5 IN A 172.0.0.2"), + test.A("hdls1.testns.svc.cluster.local. 5 IN A 172.0.0.3"), + test.A("hdls1.testns.svc.cluster.local. 5 IN A 172.0.0.4"), + test.A("hdls1.testns.svc.cluster.local. 5 IN A 172.0.0.5"), + }, + }}, + // A Service (Headless and Portless) + {Case: test.Case{ + Qname: "hdlsprtls.testns.svc.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.A("hdlsprtls.testns.svc.cluster.local. 5 IN A 172.0.0.20"), + }, + }}, + // An Endpoint with no port + {Case: test.Case{ + Qname: "172-0-0-20.hdlsprtls.testns.svc.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.A("172-0-0-20.hdlsprtls.testns.svc.cluster.local. 5 IN A 172.0.0.20"), + }, + }}, + // An Endpoint ip + {Case: test.Case{ + Qname: "172-0-0-2.hdls1.testns.svc.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.A("172-0-0-2.hdls1.testns.svc.cluster.local. 5 IN A 172.0.0.2"), + }, + }}, + // A Endpoint ip + {Case: test.Case{ + Qname: "172-0-0-3.hdls1.testns.svc.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.A("172-0-0-3.hdls1.testns.svc.cluster.local. 5 IN A 172.0.0.3"), + }, + }}, + // An Endpoint by name + {Case: test.Case{ + Qname: "dup-name.hdls1.testns.svc.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.A("dup-name.hdls1.testns.svc.cluster.local. 5 IN A 172.0.0.4"), + test.A("dup-name.hdls1.testns.svc.cluster.local. 5 IN A 172.0.0.5"), + }, + }}, + // SRV Service (Headless) + {Case: test.Case{ + Qname: "_http._tcp.hdls1.testns.svc.cluster.local.", Qtype: dns.TypeSRV, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.SRV("_http._tcp.hdls1.testns.svc.cluster.local. 5 IN SRV 0 16 80 172-0-0-2.hdls1.testns.svc.cluster.local."), + test.SRV("_http._tcp.hdls1.testns.svc.cluster.local. 5 IN SRV 0 16 80 172-0-0-3.hdls1.testns.svc.cluster.local."), + test.SRV("_http._tcp.hdls1.testns.svc.cluster.local. 5 IN SRV 0 16 80 5678-abcd--1.hdls1.testns.svc.cluster.local."), + test.SRV("_http._tcp.hdls1.testns.svc.cluster.local. 5 IN SRV 0 16 80 5678-abcd--2.hdls1.testns.svc.cluster.local."), + test.SRV("_http._tcp.hdls1.testns.svc.cluster.local. 5 IN SRV 0 16 80 dup-name.hdls1.testns.svc.cluster.local."), + }, + Extra: []dns.RR{ + test.A("172-0-0-2.hdls1.testns.svc.cluster.local. 5 IN A 172.0.0.2"), + test.A("172-0-0-3.hdls1.testns.svc.cluster.local. 5 IN A 172.0.0.3"), + test.AAAA("5678-abcd--1.hdls1.testns.svc.cluster.local. 5 IN AAAA 5678:abcd::1"), + test.AAAA("5678-abcd--2.hdls1.testns.svc.cluster.local. 5 IN AAAA 5678:abcd::2"), + test.A("dup-name.hdls1.testns.svc.cluster.local. 5 IN A 172.0.0.4"), + test.A("dup-name.hdls1.testns.svc.cluster.local. 5 IN A 172.0.0.5"), + }, + }}, + {Case: test.Case{ // An A record query for an existing headless service should return a record for each of its ipv4 endpoints + Qname: "hdls1.testns.svc.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.A("hdls1.testns.svc.cluster.local. 5 IN A 172.0.0.2"), + test.A("hdls1.testns.svc.cluster.local. 5 IN A 172.0.0.3"), + test.A("hdls1.testns.svc.cluster.local. 5 IN A 172.0.0.4"), + test.A("hdls1.testns.svc.cluster.local. 
5 IN A 172.0.0.5"), + }, + }}, + // AAAA + {Case: test.Case{ + Qname: "5678-abcd--2.hdls1.testns.svc.cluster.local", Qtype: dns.TypeAAAA, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{test.AAAA("5678-abcd--2.hdls1.testns.svc.cluster.local. 5 IN AAAA 5678:abcd::2")}, + }}, + // CNAME External + {Case: test.Case{ + Qname: "external.testns.svc.cluster.local.", Qtype: dns.TypeCNAME, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.CNAME("external.testns.svc.cluster.local. 5 IN CNAME ext.interwebs.test."), + }, + }}, + // CNAME External Truncated Lookup + { + Case: test.Case{ + Qname: "external.testns.svc.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.A("ext.interwebs.test. 5 IN A 1.2.3.4"), + test.CNAME("external.testns.svc.cluster.local. 5 IN CNAME ext.interwebs.test."), + }, + }, + Upstream: &Upstub{ + Truncated: true, + Qclass: dns.ClassINET, + Case: test.Case{ + Qname: "external.testns.svc.cluster.local.", + Qtype: dns.TypeA, + Answer: []dns.RR{ + test.A("ext.interwebs.test. 5 IN A 1.2.3.4"), + test.CNAME("external.testns.svc.cluster.local. 5 IN CNAME ext.interwebs.test."), + }, + }, + }, + Truncated: true, + }, + // CNAME External To Internal Service + {Case: test.Case{ + Qname: "external-to-service.testns.svc.cluster.local", Qtype: dns.TypeA, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.CNAME("external-to-service.testns.svc.cluster.local. 5 IN CNAME svc1.testns.svc.cluster.local."), + test.A("svc1.testns.svc.cluster.local. 5 IN A 10.0.0.1"), + }, + }}, + // AAAA Service (with an existing A record, but no AAAA record) + {Case: test.Case{ + Qname: "svc1.testns.svc.cluster.local.", Qtype: dns.TypeAAAA, + Rcode: dns.RcodeSuccess, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1499347823 7200 1800 86400 5"), + }, + }}, + // AAAA Service (non-existing service) + {Case: test.Case{ + Qname: "svc0.testns.svc.cluster.local.", Qtype: dns.TypeAAAA, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1499347823 7200 1800 86400 5"), + }, + }}, + // A Service (non-existing service) + {Case: test.Case{ + Qname: "svc0.testns.svc.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1499347823 7200 1800 86400 5"), + }, + }}, + // A Service (non-existing namespace) + {Case: test.Case{ + Qname: "svc0.svc-nons.svc.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1499347823 7200 1800 86400 5"), + }, + }}, + // TXT Schema + {Case: test.Case{ + Qname: "dns-version.cluster.local.", Qtype: dns.TypeTXT, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.TXT("dns-version.cluster.local 28800 IN TXT 1.1.0"), + }, + }}, + // A TXT record does not exist but another record for the same FQDN does + {Case: test.Case{ + Qname: "svc1.testns.svc.cluster.local.", Qtype: dns.TypeTXT, + Rcode: dns.RcodeSuccess, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1499347823 7200 1800 86400 5"), + }, + }}, + // A TXT record does not exist and neither does another record for the same FQDN + {Case: test.Case{ + Qname: "svc0.svc-nons.svc.cluster.local.", Qtype: dns.TypeTXT, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. 
hostmaster.cluster.local. 1499347823 7200 1800 86400 5"), + }, + }}, + // A Service (Headless) does not exist + {Case: test.Case{ + Qname: "bogusendpoint.hdls1.testns.svc.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1499347823 7200 1800 86400 5"), + }, + }}, + // A Service does not exist + {Case: test.Case{ + Qname: "bogusendpoint.svc0.testns.svc.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1499347823 7200 1800 86400 5"), + }, + }}, + // AAAA Service + {Case: test.Case{ + Qname: "svc6.testns.svc.cluster.local.", Qtype: dns.TypeAAAA, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.AAAA("svc6.testns.svc.cluster.local. 5 IN AAAA 1234:abcd::1"), + }, + }}, + // SRV + {Case: test.Case{ + Qname: "_http._tcp.svc6.testns.svc.cluster.local.", Qtype: dns.TypeSRV, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.SRV("_http._tcp.svc6.testns.svc.cluster.local. 5 IN SRV 0 100 80 svc6.testns.svc.cluster.local."), + }, + Extra: []dns.RR{ + test.AAAA("svc6.testns.svc.cluster.local. 5 IN AAAA 1234:abcd::1"), + }, + }}, + // AAAA Service (Headless) + {Case: test.Case{ + Qname: "hdls1.testns.svc.cluster.local.", Qtype: dns.TypeAAAA, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.AAAA("hdls1.testns.svc.cluster.local. 5 IN AAAA 5678:abcd::1"), + test.AAAA("hdls1.testns.svc.cluster.local. 5 IN AAAA 5678:abcd::2"), + }, + }}, + // AAAA Endpoint + {Case: test.Case{ + Qname: "5678-abcd--1.hdls1.testns.svc.cluster.local.", Qtype: dns.TypeAAAA, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.AAAA("5678-abcd--1.hdls1.testns.svc.cluster.local. 5 IN AAAA 5678:abcd::1"), + }, + }}, + + {Case: test.Case{ + Qname: "svc.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeSuccess, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1499347823 7200 1800 86400 5"), + }, + }}, + {Case: test.Case{ + Qname: "pod.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeSuccess, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1499347823 7200 1800 86400 5"), + }, + }}, + {Case: test.Case{ + Qname: "testns.svc.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeSuccess, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1499347823 7200 1800 86400 5"), + }, + }}, + // NS query for qname != zone (existing domain) + {Case: test.Case{ + Qname: "svc.cluster.local.", Qtype: dns.TypeNS, + Rcode: dns.RcodeSuccess, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1499347823 7200 1800 86400 5"), + }, + }}, + // NS query for qname != zone (existing domain) + {Case: test.Case{ + Qname: "testns.svc.cluster.local.", Qtype: dns.TypeNS, + Rcode: dns.RcodeSuccess, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1499347823 7200 1800 86400 5"), + }, + }}, + // NS query for qname != zone (non existing domain) + {Case: test.Case{ + Qname: "foo.cluster.local.", Qtype: dns.TypeNS, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 
1499347823 7200 1800 86400 5"), + }, + }}, + // NS query for qname != zone (non existing domain) + {Case: test.Case{ + Qname: "foo.svc.cluster.local.", Qtype: dns.TypeNS, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1499347823 7200 1800 86400 5"), + }, + }}, + // Dual Stack ClusterIP Services + {Case: test.Case{ + Qname: "svc-dual-stack.testns.svc.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.A("svc-dual-stack.testns.svc.cluster.local. 5 IN A 10.0.0.3"), + }, + }}, + {Case: test.Case{ + Qname: "svc-dual-stack.testns.svc.cluster.local.", Qtype: dns.TypeAAAA, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.AAAA("svc-dual-stack.testns.svc.cluster.local. 5 IN AAAA 10::3"), + }, + }}, + {Case: test.Case{ + Qname: "svc-dual-stack.testns.svc.cluster.local.", Qtype: dns.TypeSRV, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{test.SRV("svc-dual-stack.testns.svc.cluster.local. 5 IN SRV 0 50 80 svc-dual-stack.testns.svc.cluster.local.")}, + Extra: []dns.RR{ + test.A("svc-dual-stack.testns.svc.cluster.local. 5 IN A 10.0.0.3"), + test.AAAA("svc-dual-stack.testns.svc.cluster.local. 5 IN AAAA 10::3"), + }, + }}, + {Case: test.Case{ + Qname: "svc1.testns.svc.cluster.local.", Qtype: dns.TypeSOA, + Rcode: dns.RcodeSuccess, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1499347823 7200 1800 86400 5"), + }, + }}, + // A query for a subdomain of an external service should not resolve to the external service + {Case: test.Case{ + Qname: "endpoint.external.testns.svc.cluster.local.", Qtype: dns.TypeCNAME, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1499347823 7200 1800 86400 5"), + }, + }}, + // A query for a subdomain of a subdomain of an external service should not resolve to the external service + {Case: test.Case{ + Qname: "subdomain.subdomain.external.testns.svc.cluster.local.", Qtype: dns.TypeCNAME, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 
1499347823 7200 1800 86400 5"), + }, + }}, +} + +func TestServeDNS(t *testing.T) { + k := New([]string{"cluster.local."}) + k.APIConn = &APIConnServeTest{} + k.Next = test.NextHandler(dns.RcodeSuccess, nil) + k.Namespaces = map[string]struct{}{"testns": {}} + ctx := context.TODO() + + for i, tc := range dnsTestCases { + k.Upstream = tc.Upstream + + r := tc.Msg() + + w := dnstest.NewRecorder(&test.ResponseWriter{}) + + _, err := k.ServeDNS(ctx, w, r) + if err != tc.Error { + t.Errorf("Test %d expected no error, got %v", i, err) + return + } + if tc.Error != nil { + continue + } + + resp := w.Msg + if resp == nil { + t.Fatalf("Test %d, got nil message and no error for %q", i, r.Question[0].Name) + } + + if tc.Truncated != resp.Truncated { + t.Errorf("Expected truncation %t, got truncation %t", tc.Truncated, resp.Truncated) + } + + // Before sorting, make sure that CNAMES do not appear after their target records + if err := test.CNAMEOrder(resp); err != nil { + t.Errorf("Test %d, %v", i, err) + } + + if err := test.SortAndCheck(resp, tc.Case); err != nil { + t.Errorf("Test %d, %v", i, err) + } + } +} + +var nsTestCases = []test.Case{ + // A Service for an "exposed" namespace that "does exist" + { + Qname: "svc1.testns.svc.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.A("svc1.testns.svc.cluster.local. 5 IN A 10.0.0.1"), + }, + }, + // A service for an "exposed" namespace that "doesn't exist" + { + Qname: "svc1.nsnoexist.svc.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("cluster.local. 300 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1551484803 7200 1800 86400 30"), + }, + }, +} + +func TestServeNamespaceDNS(t *testing.T) { + k := New([]string{"cluster.local."}) + k.APIConn = &APIConnServeTest{} + k.Next = test.NextHandler(dns.RcodeSuccess, nil) + // if no namespaces are explicitly exposed, then they are all implicitly exposed + k.Namespaces = map[string]struct{}{} + ctx := context.TODO() + + for i, tc := range nsTestCases { + r := tc.Msg() + + w := dnstest.NewRecorder(&test.ResponseWriter{}) + + _, err := k.ServeDNS(ctx, w, r) + if err != tc.Error { + t.Errorf("Test %d expected no error, got %v", i, err) + return + } + if tc.Error != nil { + continue + } + + resp := w.Msg + if resp == nil { + t.Fatalf("Test %d, got nil message and no error for %q", i, r.Question[0].Name) + } + + // Before sorting, make sure that CNAMES do not appear after their target records + test.CNAMEOrder(resp) + + test.SortAndCheck(resp, tc) + } +} + +var notSyncedTestCases = []test.Case{ + { + // We should get ServerFailure instead of NameError for missing records when we kubernetes hasn't synced + Qname: "svc0.testns.svc.cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeServerFailure, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 
1499347823 7200 1800 86400 5"), + }, + }, +} + +func TestNotSyncedServeDNS(t *testing.T) { + k := New([]string{"cluster.local."}) + k.APIConn = &APIConnServeTest{ + notSynced: true, + } + k.Next = test.NextHandler(dns.RcodeSuccess, nil) + k.Namespaces = map[string]struct{}{"testns": {}} + ctx := context.TODO() + + for i, tc := range notSyncedTestCases { + r := tc.Msg() + + w := dnstest.NewRecorder(&test.ResponseWriter{}) + + _, err := k.ServeDNS(ctx, w, r) + if err != tc.Error { + t.Errorf("Test %d expected no error, got %v", i, err) + return + } + if tc.Error != nil { + continue + } + + resp := w.Msg + if resp == nil { + t.Fatalf("Test %d, got nil message and no error for %q", i, r.Question[0].Name) + } + + if err := test.CNAMEOrder(resp); err != nil { + t.Error(err) + } + + if err := test.SortAndCheck(resp, tc); err != nil { + t.Error(err) + } + } +} + +type APIConnServeTest struct { + notSynced bool +} + +func (a APIConnServeTest) HasSynced() bool { return !a.notSynced } +func (APIConnServeTest) Run() {} +func (APIConnServeTest) Stop() error { return nil } +func (APIConnServeTest) EpIndexReverse(string) []*object.Endpoints { return nil } +func (APIConnServeTest) SvcIndexReverse(string) []*object.Service { return nil } +func (APIConnServeTest) SvcExtIndexReverse(string) []*object.Service { return nil } +func (APIConnServeTest) Modified(bool) int64 { return int64(3) } + +func (APIConnServeTest) PodIndex(ip string) []*object.Pod { + if ip != "10.240.0.1" { + return []*object.Pod{} + } + a := []*object.Pod{ + {Namespace: "podns", Name: "foo", PodIP: "10.240.0.1"}, // Remote IP set in test.ResponseWriter + } + return a +} + +var svcIndex = map[string][]*object.Service{ + "kubedns.kube-system": { + { + Name: "kubedns", + Namespace: "kube-system", + Type: api.ServiceTypeClusterIP, + ClusterIPs: []string{"10.0.0.10"}, + Ports: []api.ServicePort{ + {Name: "dns", Protocol: "udp", Port: 53}, + }, + }, + }, + "svc1.testns": { + { + Name: "svc1", + Namespace: "testns", + Type: api.ServiceTypeClusterIP, + ClusterIPs: []string{"10.0.0.1"}, + Ports: []api.ServicePort{ + {Name: "http", Protocol: "tcp", Port: 80}, + }, + }, + }, + "svcempty.testns": { + { + Name: "svcempty", + Namespace: "testns", + Type: api.ServiceTypeClusterIP, + ClusterIPs: []string{"10.0.0.1"}, + Ports: []api.ServicePort{ + {Name: "http", Protocol: "tcp", Port: 80}, + }, + }, + }, + "svc6.testns": { + { + Name: "svc6", + Namespace: "testns", + Type: api.ServiceTypeClusterIP, + ClusterIPs: []string{"1234:abcd::1"}, + Ports: []api.ServicePort{ + {Name: "http", Protocol: "tcp", Port: 80}, + }, + }, + }, + "hdls1.testns": { + { + Name: "hdls1", + Namespace: "testns", + Type: api.ServiceTypeClusterIP, + ClusterIPs: []string{api.ClusterIPNone}, + }, + }, + "external.testns": { + { + Name: "external", + Namespace: "testns", + ExternalName: "ext.interwebs.test", + Type: api.ServiceTypeExternalName, + Ports: []api.ServicePort{ + {Name: "http", Protocol: "tcp", Port: 80}, + }, + }, + }, + "external-to-service.testns": { + { + Name: "external-to-service", + Namespace: "testns", + ExternalName: "svc1.testns.svc.cluster.local.", + Type: api.ServiceTypeExternalName, + Ports: []api.ServicePort{ + {Name: "http", Protocol: "tcp", Port: 80}, + }, + }, + }, + "hdlsprtls.testns": { + { + Name: "hdlsprtls", + Namespace: "testns", + Type: api.ServiceTypeClusterIP, + ClusterIPs: []string{api.ClusterIPNone}, + }, + }, + "svc1.unexposedns": { + { + Name: "svc1", + Namespace: "unexposedns", + Type: api.ServiceTypeClusterIP, + ClusterIPs: 
[]string{"10.0.0.2"}, + Ports: []api.ServicePort{ + {Name: "http", Protocol: "tcp", Port: 80}, + }, + }, + }, + "svc-dual-stack.testns": { + { + Name: "svc-dual-stack", + Namespace: "testns", + Type: api.ServiceTypeClusterIP, + ClusterIPs: []string{"10.0.0.3", "10::3"}, Ports: []api.ServicePort{ + {Name: "http", Protocol: "tcp", Port: 80}, + }, + }, + }, +} + +func (APIConnServeTest) SvcIndex(s string) []*object.Service { return svcIndex[s] } + +func (APIConnServeTest) ServiceList() []*object.Service { + var svcs []*object.Service + for _, svc := range svcIndex { + svcs = append(svcs, svc...) + } + return svcs +} + +var epsIndex = map[string][]*object.Endpoints{ + "kubedns.kube-system": {{ + Subsets: []object.EndpointSubset{ + { + Addresses: []object.EndpointAddress{ + {IP: "172.0.0.100"}, + }, + Ports: []object.EndpointPort{ + {Port: 53, Protocol: "udp", Name: "dns"}, + }, + }, + }, + Name: "kubedns", + Namespace: "kube-system", + Index: object.EndpointsKey("kubedns", "kube-system"), + }}, + "svc1.testns": {{ + Subsets: []object.EndpointSubset{ + { + Addresses: []object.EndpointAddress{ + {IP: "172.0.0.1", Hostname: "ep1a"}, + }, + Ports: []object.EndpointPort{ + {Port: 80, Protocol: "tcp", Name: "http"}, + }, + }, + }, + Name: "svc1-slice1", + Namespace: "testns", + Index: object.EndpointsKey("svc1", "testns"), + }}, + "svcempty.testns": {{ + Subsets: []object.EndpointSubset{ + { + Addresses: nil, + Ports: []object.EndpointPort{ + {Port: 80, Protocol: "tcp", Name: "http"}, + }, + }, + }, + Name: "svcempty-slice1", + Namespace: "testns", + Index: object.EndpointsKey("svcempty", "testns"), + }}, + "hdls1.testns": {{ + Subsets: []object.EndpointSubset{ + { + Addresses: []object.EndpointAddress{ + {IP: "172.0.0.2"}, + {IP: "172.0.0.3"}, + {IP: "172.0.0.4", Hostname: "dup-name"}, + {IP: "172.0.0.5", Hostname: "dup-name"}, + {IP: "5678:abcd::1"}, + {IP: "5678:abcd::2"}, + }, + Ports: []object.EndpointPort{ + {Port: 80, Protocol: "tcp", Name: "http"}, + }, + }, + }, + Name: "hdls1-slice1", + Namespace: "testns", + Index: object.EndpointsKey("hdls1", "testns"), + }}, + "hdlsprtls.testns": {{ + Subsets: []object.EndpointSubset{ + { + Addresses: []object.EndpointAddress{ + {IP: "172.0.0.20"}, + }, + Ports: []object.EndpointPort{{Port: -1}}, + }, + }, + Name: "hdlsprtls-slice1", + Namespace: "testns", + Index: object.EndpointsKey("hdlsprtls", "testns"), + }}, +} + +func (APIConnServeTest) EpIndex(s string) []*object.Endpoints { + return epsIndex[s] +} + +func (APIConnServeTest) EndpointsList() []*object.Endpoints { + var eps []*object.Endpoints + for _, ep := range epsIndex { + eps = append(eps, ep...) + } + return eps +} + +func (APIConnServeTest) GetNodeByName(ctx context.Context, name string) (*api.Node, error) { + return &api.Node{ + ObjectMeta: meta.ObjectMeta{ + Name: "test.node.foo.bar", + }, + }, nil +} + +func (APIConnServeTest) GetNamespaceByName(name string) (*object.Namespace, error) { + if name == "pod-nons" { // handler_pod_verified_test.go uses this for non-existent namespace. 
+ return nil, fmt.Errorf("namespace not found") + } + if name == "nsnoexist" { + return nil, fmt.Errorf("namespace not found") + } + return &object.Namespace{ + Name: name, + }, nil +} + +// Upstub implements an Upstreamer that returns a set response for test purposes +type Upstub struct { + test.Case + Truncated bool + Qclass uint16 +} + +// Lookup returns a set response +func (t *Upstub) Lookup(ctx context.Context, state request.Request, name string, typ uint16) (*dns.Msg, error) { + var answer []dns.RR + // if query type is not CNAME, remove any CNAME with same name as qname from the answer + if t.Qtype != dns.TypeCNAME { + for _, a := range t.Answer { + if c, ok := a.(*dns.CNAME); ok && c.Header().Name == t.Qname { + continue + } + answer = append(answer, a) + } + } else { + answer = t.Answer + } + + return &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Response: true, + Truncated: t.Truncated, + Rcode: t.Rcode, + }, + Question: []dns.Question{{Name: t.Qname, Qtype: t.Qtype, Qclass: t.Qclass}}, + Answer: answer, + Extra: t.Extra, + Ns: t.Ns, + }, nil +} diff --git a/plugin/kubernetes/informer_test.go b/plugin/kubernetes/informer_test.go new file mode 100644 index 0000000..ee5186a --- /dev/null +++ b/plugin/kubernetes/informer_test.go @@ -0,0 +1,120 @@ +package kubernetes + +import ( + "fmt" + "testing" + + "github.com/coredns/coredns/plugin/kubernetes/object" + + api "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/tools/cache" +) + +func TestDefaultProcessor(t *testing.T) { + pbuild := object.DefaultProcessor(object.ToService, nil) + reh := cache.ResourceEventHandlerFuncs{} + idx := cache.NewIndexer(cache.DeletionHandlingMetaNamespaceKeyFunc, cache.Indexers{}) + processor := pbuild(idx, reh) + testProcessor(t, processor, idx) +} + +func testProcessor(t *testing.T, processor cache.ProcessFunc, idx cache.Indexer) { + obj := &api.Service{ + ObjectMeta: metav1.ObjectMeta{Name: "service1", Namespace: "test1"}, + Spec: api.ServiceSpec{ + ClusterIP: "1.2.3.4", + ClusterIPs: []string{"1.2.3.4"}, + Ports: []api.ServicePort{{Port: 80}}, + }, + } + obj2 := &api.Service{ + ObjectMeta: metav1.ObjectMeta{Name: "service2", Namespace: "test1"}, + Spec: api.ServiceSpec{ + ClusterIP: "5.6.7.8", + ClusterIPs: []string{"5.6.7.8"}, + Ports: []api.ServicePort{{Port: 80}}, + }, + } + + // Add the objects + err := processor(cache.Deltas{ + {Type: cache.Added, Object: obj.DeepCopy()}, + {Type: cache.Added, Object: obj2.DeepCopy()}, + }, false) + if err != nil { + t.Fatalf("add failed: %v", err) + } + got, exists, err := idx.Get(obj) + if err != nil { + t.Fatalf("get added object failed: %v", err) + } + if !exists { + t.Fatal("added object not found in index") + } + svc, ok := got.(*object.Service) + if !ok { + t.Fatal("object in index was incorrect type") + } + if fmt.Sprintf("%v", svc.ClusterIPs) != fmt.Sprintf("%v", obj.Spec.ClusterIPs) { + t.Fatalf("expected '%v', got '%v'", obj.Spec.ClusterIPs, svc.ClusterIPs) + } + + // Update an object + obj.Spec.ClusterIP = "1.2.3.5" + err = processor(cache.Deltas{{ + Type: cache.Updated, + Object: obj.DeepCopy(), + }}, false) + if err != nil { + t.Fatalf("update failed: %v", err) + } + got, exists, err = idx.Get(obj) + if err != nil { + t.Fatalf("get updated object failed: %v", err) + } + if !exists { + t.Fatal("updated object not found in index") + } + svc, ok = got.(*object.Service) + if !ok { + t.Fatal("object in index was incorrect type") + } + if fmt.Sprintf("%v", svc.ClusterIPs) != fmt.Sprintf("%v", obj.Spec.ClusterIPs) { + 
t.Fatalf("expected '%v', got '%v'", obj.Spec.ClusterIPs, svc.ClusterIPs) + } + + // Delete an object + err = processor(cache.Deltas{{ + Type: cache.Deleted, + Object: obj2.DeepCopy(), + }}, false) + if err != nil { + t.Fatalf("delete test failed: %v", err) + } + _, exists, err = idx.Get(obj2) + if err != nil { + t.Fatalf("get deleted object failed: %v", err) + } + if exists { + t.Fatal("deleted object found in index") + } + + // Delete an object via tombstone + key, _ := cache.MetaNamespaceKeyFunc(obj) + tombstone := cache.DeletedFinalStateUnknown{Key: key, Obj: svc} + err = processor(cache.Deltas{{ + Type: cache.Deleted, + Object: tombstone, + }}, false) + if err != nil { + t.Fatalf("tombstone delete test failed: %v", err) + } + _, exists, err = idx.Get(svc) + if err != nil { + t.Fatalf("get tombstone deleted object failed: %v", err) + } + if exists { + t.Fatal("tombstone deleted object found in index") + } +} diff --git a/plugin/kubernetes/kubernetes.go b/plugin/kubernetes/kubernetes.go new file mode 100644 index 0000000..cea23d8 --- /dev/null +++ b/plugin/kubernetes/kubernetes.go @@ -0,0 +1,533 @@ +// Package kubernetes provides the kubernetes backend. +package kubernetes + +import ( + "context" + "errors" + "fmt" + "net" + "strings" + "time" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/etcd/msg" + "github.com/coredns/coredns/plugin/kubernetes/object" + "github.com/coredns/coredns/plugin/pkg/dnsutil" + "github.com/coredns/coredns/plugin/pkg/fall" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" + api "k8s.io/api/core/v1" + meta "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/labels" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" + "k8s.io/client-go/tools/clientcmd" + clientcmdapi "k8s.io/client-go/tools/clientcmd/api" +) + +// Kubernetes implements a plugin that connects to a Kubernetes cluster. +type Kubernetes struct { + Next plugin.Handler + Zones []string + Upstream Upstreamer + APIServerList []string + APICertAuth string + APIClientCert string + APIClientKey string + ClientConfig clientcmd.ClientConfig + APIConn dnsController + Namespaces map[string]struct{} + podMode string + endpointNameMode bool + Fall fall.F + ttl uint32 + opts dnsControlOpts + primaryZoneIndex int + localIPs []net.IP + autoPathSearch []string // Local search path from /etc/resolv.conf. Needed for autopath. +} + +// Upstreamer is used to resolve CNAME or other external targets +type Upstreamer interface { + Lookup(ctx context.Context, state request.Request, name string, typ uint16) (*dns.Msg, error) +} + +// New returns a initialized Kubernetes. It default interfaceAddrFunc to return 127.0.0.1. All other +// values default to their zero value, primaryZoneIndex will thus point to the first zone. 
+func New(zones []string) *Kubernetes { + k := new(Kubernetes) + k.Zones = zones + k.Namespaces = make(map[string]struct{}) + k.podMode = podModeDisabled + k.ttl = defaultTTL + + return k +} + +const ( + // podModeDisabled is the default value where pod requests are ignored + podModeDisabled = "disabled" + // podModeVerified is where Pod requests are answered only if they exist + podModeVerified = "verified" + // podModeInsecure is where pod requests are answered without verifying they exist + podModeInsecure = "insecure" + // DNSSchemaVersion is the schema version: https://github.com/kubernetes/dns/blob/master/docs/specification.md + DNSSchemaVersion = "1.1.0" + // Svc is the DNS schema for kubernetes services + Svc = "svc" + // Pod is the DNS schema for kubernetes pods + Pod = "pod" + // defaultTTL to apply to all answers. + defaultTTL = 5 +) + +var ( + errNoItems = errors.New("no items found") + errNsNotExposed = errors.New("namespace is not exposed") + errInvalidRequest = errors.New("invalid query name") +) + +// Services implements the ServiceBackend interface. +func (k *Kubernetes) Services(ctx context.Context, state request.Request, exact bool, opt plugin.Options) (svcs []msg.Service, err error) { + // We're looking again at types, which we've already done in ServeDNS, but there are some types k8s just can't answer. + switch state.QType() { + case dns.TypeTXT: + // 1 label + zone, label must be "dns-version". + t, _ := dnsutil.TrimZone(state.Name(), state.Zone) + + // Hard code the only valid TXT - "dns-version.<zone>" + segs := dns.SplitDomainName(t) + if len(segs) == 1 && segs[0] == "dns-version" { + svc := msg.Service{Text: DNSSchemaVersion, TTL: 28800, Key: msg.Path(state.QName(), coredns)} + return []msg.Service{svc}, nil + } + + // Check if we have an existing record for this query of another type + services, _ := k.Records(ctx, state, false) + + if len(services) > 0 { + // If so we return an empty NOERROR + return nil, nil + } + + // Return NXDOMAIN for no match + return nil, errNoItems + + case dns.TypeNS: + // We can only get here if the qname equals the zone, see ServeDNS in handler.go. + nss := k.nsAddrs(false, false, state.Zone) + var svcs []msg.Service + for _, ns := range nss { + if ns.Header().Rrtype == dns.TypeA { + svcs = append(svcs, msg.Service{Host: ns.(*dns.A).A.String(), Key: msg.Path(ns.Header().Name, coredns), TTL: k.ttl}) + continue + } + if ns.Header().Rrtype == dns.TypeAAAA { + svcs = append(svcs, msg.Service{Host: ns.(*dns.AAAA).AAAA.String(), Key: msg.Path(ns.Header().Name, coredns), TTL: k.ttl}) + } + } + return svcs, nil + } + + if isDefaultNS(state.Name(), state.Zone) { + nss := k.nsAddrs(false, false, state.Zone) + var svcs []msg.Service + for _, ns := range nss { + if ns.Header().Rrtype == dns.TypeA && state.QType() == dns.TypeA { + svcs = append(svcs, msg.Service{Host: ns.(*dns.A).A.String(), Key: msg.Path(state.QName(), coredns), TTL: k.ttl}) + continue + } + if ns.Header().Rrtype == dns.TypeAAAA && state.QType() == dns.TypeAAAA { + svcs = append(svcs, msg.Service{Host: ns.(*dns.AAAA).AAAA.String(), Key: msg.Path(state.QName(), coredns), TTL: k.ttl}) + } + } + return svcs, nil + } + + s, e := k.Records(ctx, state, false) + + // SRV for external services is not yet implemented, so remove those records. 
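+	// Concretely: for SRV queries the loop below keeps only services whose host is an IP address and drops CNAME targets such as ExternalName services.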
+ + if state.QType() != dns.TypeSRV { + return s, e + } + + internal := []msg.Service{} + for _, svc := range s { + if t, _ := svc.HostType(); t != dns.TypeCNAME { + internal = append(internal, svc) + } + } + + return internal, e +} + +// primaryZone will return the first non-reverse zone being handled by this plugin +func (k *Kubernetes) primaryZone() string { return k.Zones[k.primaryZoneIndex] } + +// Lookup implements the ServiceBackend interface. +func (k *Kubernetes) Lookup(ctx context.Context, state request.Request, name string, typ uint16) (*dns.Msg, error) { + return k.Upstream.Lookup(ctx, state, name, typ) +} + +// IsNameError implements the ServiceBackend interface. +func (k *Kubernetes) IsNameError(err error) bool { + return err == errNoItems || err == errNsNotExposed || err == errInvalidRequest +} + +func (k *Kubernetes) getClientConfig() (*rest.Config, error) { + if k.ClientConfig != nil { + return k.ClientConfig.ClientConfig() + } + loadingRules := &clientcmd.ClientConfigLoadingRules{} + overrides := &clientcmd.ConfigOverrides{} + clusterinfo := clientcmdapi.Cluster{} + authinfo := clientcmdapi.AuthInfo{} + + // Connect to API from in cluster + if len(k.APIServerList) == 0 { + cc, err := rest.InClusterConfig() + if err != nil { + return nil, err + } + cc.ContentType = "application/vnd.kubernetes.protobuf" + return cc, err + } + + // Connect to API from out of cluster + // Only the first one is used. We will deprecate multiple endpoints later. + clusterinfo.Server = k.APIServerList[0] + + if len(k.APICertAuth) > 0 { + clusterinfo.CertificateAuthority = k.APICertAuth + } + if len(k.APIClientCert) > 0 { + authinfo.ClientCertificate = k.APIClientCert + } + if len(k.APIClientKey) > 0 { + authinfo.ClientKey = k.APIClientKey + } + + overrides.ClusterInfo = clusterinfo + overrides.AuthInfo = authinfo + clientConfig := clientcmd.NewNonInteractiveDeferredLoadingClientConfig(loadingRules, overrides) + + cc, err := clientConfig.ClientConfig() + if err != nil { + return nil, err + } + cc.ContentType = "application/vnd.kubernetes.protobuf" + return cc, err +} + +// InitKubeCache initializes a new Kubernetes cache. 
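+// The returned onStart callback starts the controller and polls for the initial cache sync, logging a note every 500ms and giving up the wait (while still starting the server) after 5 seconds; onShut stops the controller.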
+func (k *Kubernetes) InitKubeCache(ctx context.Context) (onStart func() error, onShut func() error, err error) { + config, err := k.getClientConfig() + if err != nil { + return nil, nil, err + } + + kubeClient, err := kubernetes.NewForConfig(config) + if err != nil { + return nil, nil, fmt.Errorf("failed to create kubernetes notification controller: %q", err) + } + + if k.opts.labelSelector != nil { + var selector labels.Selector + selector, err = meta.LabelSelectorAsSelector(k.opts.labelSelector) + if err != nil { + return nil, nil, fmt.Errorf("unable to create Selector for LabelSelector '%s': %q", k.opts.labelSelector, err) + } + k.opts.selector = selector + } + + if k.opts.namespaceLabelSelector != nil { + var selector labels.Selector + selector, err = meta.LabelSelectorAsSelector(k.opts.namespaceLabelSelector) + if err != nil { + return nil, nil, fmt.Errorf("unable to create Selector for LabelSelector '%s': %q", k.opts.namespaceLabelSelector, err) + } + k.opts.namespaceSelector = selector + } + + k.opts.initPodCache = k.podMode == podModeVerified + + k.opts.zones = k.Zones + k.opts.endpointNameMode = k.endpointNameMode + + k.APIConn = newdnsController(ctx, kubeClient, k.opts) + + onStart = func() error { + go func() { + k.APIConn.Run() + }() + + timeout := 5 * time.Second + timeoutTicker := time.NewTicker(timeout) + defer timeoutTicker.Stop() + logDelay := 500 * time.Millisecond + logTicker := time.NewTicker(logDelay) + defer logTicker.Stop() + checkSyncTicker := time.NewTicker(100 * time.Millisecond) + defer checkSyncTicker.Stop() + for { + select { + case <-checkSyncTicker.C: + if k.APIConn.HasSynced() { + return nil + } + case <-logTicker.C: + log.Info("waiting for Kubernetes API before starting server") + case <-timeoutTicker.C: + log.Warning("starting server with unsynced Kubernetes API") + return nil + } + } + } + + onShut = func() error { + return k.APIConn.Stop() + } + + return onStart, onShut, err +} + +// Records looks up services in kubernetes. 
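+// It rejects reverse lookups and queries for unexposed namespaces, then delegates to findPods or findServices depending on whether the name parsed as a pod or a service query.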
+func (k *Kubernetes) Records(ctx context.Context, state request.Request, exact bool) ([]msg.Service, error) { + r, e := parseRequest(state.Name(), state.Zone) + if e != nil { + return nil, e + } + if r.podOrSvc == "" { + return nil, nil + } + + if dnsutil.IsReverse(state.Name()) > 0 { + return nil, errNoItems + } + + if !k.namespaceExposed(r.namespace) { + return nil, errNsNotExposed + } + + if r.podOrSvc == Pod { + pods, err := k.findPods(r, state.Zone) + return pods, err + } + + services, err := k.findServices(r, state.Zone) + return services, err +} + +func endpointHostname(addr object.EndpointAddress, endpointNameMode bool) string { + if addr.Hostname != "" { + return addr.Hostname + } + if endpointNameMode && addr.TargetRefName != "" { + return addr.TargetRefName + } + if strings.Contains(addr.IP, ".") { + return strings.Replace(addr.IP, ".", "-", -1) + } + if strings.Contains(addr.IP, ":") { + return strings.Replace(addr.IP, ":", "-", -1) + } + return "" +} + +func (k *Kubernetes) findPods(r recordRequest, zone string) (pods []msg.Service, err error) { + if k.podMode == podModeDisabled { + return nil, errNoItems + } + + namespace := r.namespace + if !k.namespaceExposed(namespace) { + return nil, errNoItems + } + + podname := r.service + + // handle empty pod name + if podname == "" { + if k.namespaceExposed(namespace) { + // NODATA + return nil, nil + } + // NXDOMAIN + return nil, errNoItems + } + + zonePath := msg.Path(zone, coredns) + ip := "" + if strings.Count(podname, "-") == 3 && !strings.Contains(podname, "--") { + ip = strings.ReplaceAll(podname, "-", ".") + } else { + ip = strings.ReplaceAll(podname, "-", ":") + } + + if k.podMode == podModeInsecure { + if !k.namespaceExposed(namespace) { // namespace does not exist + return nil, errNoItems + } + + // If ip does not parse as an IP address, we return an error, otherwise we assume a CNAME and will try to resolve it in backend_lookup.go + if net.ParseIP(ip) == nil { + return nil, errNoItems + } + + return []msg.Service{{Key: strings.Join([]string{zonePath, Pod, namespace, podname}, "/"), Host: ip, TTL: k.ttl}}, err + } + + // PodModeVerified + err = errNoItems + + for _, p := range k.APIConn.PodIndex(ip) { + // check for matching ip and namespace + if ip == p.PodIP && match(namespace, p.Namespace) { + s := msg.Service{Key: strings.Join([]string{zonePath, Pod, namespace, podname}, "/"), Host: ip, TTL: k.ttl} + pods = append(pods, s) + + err = nil + } + } + return pods, err +} + +// findServices returns the services matching r from the cache. +func (k *Kubernetes) findServices(r recordRequest, zone string) (services []msg.Service, err error) { + if !k.namespaceExposed(r.namespace) { + return nil, errNoItems + } + + // handle empty service name + if r.service == "" { + if k.namespaceExposed(r.namespace) { + // NODATA + return nil, nil + } + // NXDOMAIN + return nil, errNoItems + } + + err = errNoItems + + var ( + endpointsListFunc func() []*object.Endpoints + endpointsList []*object.Endpoints + serviceList []*object.Service + ) + + idx := object.ServiceKey(r.service, r.namespace) + serviceList = k.APIConn.SvcIndex(idx) + endpointsListFunc = func() []*object.Endpoints { return k.APIConn.EpIndex(idx) } + + zonePath := msg.Path(zone, coredns) + for _, svc := range serviceList { + if !(match(r.namespace, svc.Namespace) && match(r.service, svc.Name)) { + continue + } + + // If "ignore empty_service" option is set and no endpoints exist, return NXDOMAIN unless + // it's a headless or externalName service (covered below). 
+ if k.opts.ignoreEmptyService && svc.Type != api.ServiceTypeExternalName && !svc.Headless() { // serve NXDOMAIN if no endpoint is able to answer + podsCount := 0 + for _, ep := range endpointsListFunc() { + for _, eps := range ep.Subsets { + podsCount += len(eps.Addresses) + } + } + + if podsCount == 0 { + continue + } + } + + // External service + if svc.Type == api.ServiceTypeExternalName { + // External services do not have endpoints, nor can we accept port/protocol pseudo subdomains in an SRV query, so skip this service if endpoint, port, or protocol is non-empty in the request + if r.endpoint != "" || r.port != "" || r.protocol != "" { + continue + } + s := msg.Service{Key: strings.Join([]string{zonePath, Svc, svc.Namespace, svc.Name}, "/"), Host: svc.ExternalName, TTL: k.ttl} + if t, _ := s.HostType(); t == dns.TypeCNAME { + s.Key = strings.Join([]string{zonePath, Svc, svc.Namespace, svc.Name}, "/") + services = append(services, s) + + err = nil + } + continue + } + + // Endpoint query or headless service + if svc.Headless() || r.endpoint != "" { + if endpointsList == nil { + endpointsList = endpointsListFunc() + } + + for _, ep := range endpointsList { + if object.EndpointsKey(svc.Name, svc.Namespace) != ep.Index { + continue + } + + for _, eps := range ep.Subsets { + for _, addr := range eps.Addresses { + // See comments in parse.go parseRequest about the endpoint handling. + if r.endpoint != "" { + if !match(r.endpoint, endpointHostname(addr, k.endpointNameMode)) { + continue + } + } + + for _, p := range eps.Ports { + if !(matchPortAndProtocol(r.port, p.Name, r.protocol, p.Protocol)) { + continue + } + s := msg.Service{Host: addr.IP, Port: int(p.Port), TTL: k.ttl} + s.Key = strings.Join([]string{zonePath, Svc, svc.Namespace, svc.Name, endpointHostname(addr, k.endpointNameMode)}, "/") + + err = nil + + services = append(services, s) + } + } + } + } + continue + } + + // ClusterIP service + for _, p := range svc.Ports { + if !(matchPortAndProtocol(r.port, p.Name, r.protocol, string(p.Protocol))) { + continue + } + + err = nil + + for _, ip := range svc.ClusterIPs { + s := msg.Service{Host: ip, Port: int(p.Port), TTL: k.ttl} + s.Key = strings.Join([]string{zonePath, Svc, svc.Namespace, svc.Name}, "/") + services = append(services, s) + } + } + } + return services, err +} + +// Serial return the SOA serial. +func (k *Kubernetes) Serial(state request.Request) uint32 { return uint32(k.APIConn.Modified(false)) } + +// MinTTL returns the minimal TTL. +func (k *Kubernetes) MinTTL(state request.Request) uint32 { return k.ttl } + +// match checks if a and b are equal. 
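+// The comparison is case-insensitive, so for example match("TestNS", "testns") is true.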
+func match(a, b string) bool { + return strings.EqualFold(a, b) +} + +// matchPortAndProtocol matches port and protocol, permitting the 'a' inputs to be wild +func matchPortAndProtocol(aPort, bPort, aProtocol, bProtocol string) bool { + return (match(aPort, bPort) || aPort == "") && (match(aProtocol, bProtocol) || aProtocol == "") +} + +const coredns = "c" // used as a fake key prefix in msg.Service diff --git a/plugin/kubernetes/kubernetes_apex_test.go b/plugin/kubernetes/kubernetes_apex_test.go new file mode 100644 index 0000000..7531e21 --- /dev/null +++ b/plugin/kubernetes/kubernetes_apex_test.go @@ -0,0 +1,92 @@ +package kubernetes + +import ( + "context" + "net" + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +var kubeApexCases = []test.Case{ + { + Qname: "cluster.local.", Qtype: dns.TypeSOA, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1499347823 7200 1800 86400 5"), + }, + }, + { + Qname: "cluster.local.", Qtype: dns.TypeHINFO, + Rcode: dns.RcodeSuccess, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1499347823 7200 1800 86400 5"), + }, + }, + { + Qname: "cluster.local.", Qtype: dns.TypeNS, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.NS("cluster.local. 5 IN NS ns.dns.cluster.local."), + }, + Extra: []dns.RR{ + test.A("ns.dns.cluster.local. 5 IN A 127.0.0.1"), + }, + }, + { + Qname: "cluster.local.", Qtype: dns.TypeA, + Rcode: dns.RcodeSuccess, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1499347823 7200 1800 86400 5"), + }, + }, + { + Qname: "cluster.local.", Qtype: dns.TypeAAAA, + Rcode: dns.RcodeSuccess, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1499347823 7200 1800 86400 5"), + }, + }, + { + Qname: "cluster.local.", Qtype: dns.TypeSRV, + Rcode: dns.RcodeSuccess, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 
1499347823 7200 1800 86400 5"), + }, + }, +} + +func TestServeDNSApex(t *testing.T) { + k := New([]string{"cluster.local."}) + k.APIConn = &APIConnServeTest{} + k.Next = test.NextHandler(dns.RcodeSuccess, nil) + k.localIPs = []net.IP{net.ParseIP("127.0.0.1")} + ctx := context.TODO() + + for i, tc := range kubeApexCases { + r := tc.Msg() + + w := dnstest.NewRecorder(&test.ResponseWriter{}) + + _, err := k.ServeDNS(ctx, w, r) + if err != tc.Error { + t.Errorf("Test %d, expected no error, got %v", i, err) + return + } + if tc.Error != nil { + continue + } + + resp := w.Msg + if resp == nil { + t.Fatalf("Test %d, got nil message and no error ford", i) + } + + if err := test.SortAndCheck(resp, tc); err != nil { + t.Errorf("Test %d: %v", i, err) + } + } +} diff --git a/plugin/kubernetes/kubernetes_test.go b/plugin/kubernetes/kubernetes_test.go new file mode 100644 index 0000000..acdfd4c --- /dev/null +++ b/plugin/kubernetes/kubernetes_test.go @@ -0,0 +1,367 @@ +package kubernetes + +import ( + "context" + "net" + "testing" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/kubernetes/object" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" + api "k8s.io/api/core/v1" + meta "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func TestEndpointHostname(t *testing.T) { + var tests = []struct { + ip string + hostname string + expected string + podName string + endpointNameMode bool + }{ + {"10.11.12.13", "", "10-11-12-13", "", false}, + {"10.11.12.13", "epname", "epname", "", false}, + {"10.11.12.13", "", "10-11-12-13", "hello-abcde", false}, + {"10.11.12.13", "epname", "epname", "hello-abcde", false}, + {"10.11.12.13", "epname", "epname", "hello-abcde", true}, + {"10.11.12.13", "", "hello-abcde", "hello-abcde", true}, + } + for _, test := range tests { + result := endpointHostname(object.EndpointAddress{IP: test.ip, Hostname: test.hostname, TargetRefName: test.podName}, test.endpointNameMode) + if result != test.expected { + t.Errorf("Expected endpoint name for (ip:%v hostname:%v) to be '%v', but got '%v'", test.ip, test.hostname, test.expected, result) + } + } +} + +type APIConnServiceTest struct{} + +func (APIConnServiceTest) HasSynced() bool { return true } +func (APIConnServiceTest) Run() {} +func (APIConnServiceTest) Stop() error { return nil } +func (APIConnServiceTest) PodIndex(string) []*object.Pod { return nil } +func (APIConnServiceTest) SvcIndexReverse(string) []*object.Service { return nil } +func (APIConnServiceTest) SvcExtIndexReverse(string) []*object.Service { return nil } +func (APIConnServiceTest) EpIndexReverse(string) []*object.Endpoints { return nil } +func (APIConnServiceTest) Modified(bool) int64 { return 0 } + +func (APIConnServiceTest) SvcIndex(string) []*object.Service { + svcs := []*object.Service{ + { + Name: "svc1", + Namespace: "testns", + ClusterIPs: []string{"10.0.0.1"}, + Ports: []api.ServicePort{ + {Name: "http", Protocol: "tcp", Port: 80}, + }, + }, + { + Name: "svc-dual-stack", + Namespace: "testns", + ClusterIPs: []string{"10.0.0.2", "10::2"}, + Ports: []api.ServicePort{ + {Name: "http", Protocol: "tcp", Port: 80}, + }, + }, + { + Name: "hdls1", + Namespace: "testns", + ClusterIPs: []string{api.ClusterIPNone}, + }, + { + Name: "external", + Namespace: "testns", + ExternalName: "coredns.io", + Type: api.ServiceTypeExternalName, + Ports: []api.ServicePort{ + {Name: "http", Protocol: "tcp", Port: 80}, + }, + }, + } + return svcs +} + +func (APIConnServiceTest) ServiceList() []*object.Service { + svcs := []*object.Service{ + { + 
Name: "svc1", + Namespace: "testns", + ClusterIPs: []string{"10.0.0.1"}, + Ports: []api.ServicePort{ + {Name: "http", Protocol: "tcp", Port: 80}, + }, + }, + { + Name: "svc-dual-stack", + Namespace: "testns", + ClusterIPs: []string{"10.0.0.2", "10::2"}, + Ports: []api.ServicePort{ + {Name: "http", Protocol: "tcp", Port: 80}, + }, + }, + { + Name: "hdls1", + Namespace: "testns", + ClusterIPs: []string{api.ClusterIPNone}, + }, + { + Name: "external", + Namespace: "testns", + ExternalName: "coredns.io", + Type: api.ServiceTypeExternalName, + Ports: []api.ServicePort{ + {Name: "http", Protocol: "tcp", Port: 80}, + }, + }, + } + return svcs +} + +func (APIConnServiceTest) EpIndex(string) []*object.Endpoints { + eps := []*object.Endpoints{ + { + Subsets: []object.EndpointSubset{ + { + Addresses: []object.EndpointAddress{ + {IP: "172.0.0.1", Hostname: "ep1a"}, + }, + Ports: []object.EndpointPort{ + {Port: 80, Protocol: "tcp", Name: "http"}, + }, + }, + }, + Name: "svc1-slice1", + Namespace: "testns", + Index: object.EndpointsKey("svc1", "testns"), + }, + { + Subsets: []object.EndpointSubset{ + { + Addresses: []object.EndpointAddress{ + {IP: "172.0.0.2"}, + }, + Ports: []object.EndpointPort{ + {Port: 80, Protocol: "tcp", Name: "http"}, + }, + }, + }, + Name: "hdls1-slice1", + Namespace: "testns", + Index: object.EndpointsKey("hdls1", "testns"), + }, + { + Subsets: []object.EndpointSubset{ + { + Addresses: []object.EndpointAddress{ + {IP: "10.9.8.7", NodeName: "test.node.foo.bar"}, + }, + }, + }, + }, + } + return eps +} + +func (APIConnServiceTest) EndpointsList() []*object.Endpoints { + eps := []*object.Endpoints{ + { + Subsets: []object.EndpointSubset{ + { + Addresses: []object.EndpointAddress{ + {IP: "172.0.0.1", Hostname: "ep1a"}, + }, + Ports: []object.EndpointPort{ + {Port: 80, Protocol: "tcp", Name: "http"}, + }, + }, + }, + Name: "svc1-slice1", + Namespace: "testns", + Index: object.EndpointsKey("svc1", "testns"), + }, + { + Subsets: []object.EndpointSubset{ + { + Addresses: []object.EndpointAddress{ + {IP: "172.0.0.2"}, + }, + Ports: []object.EndpointPort{ + {Port: 80, Protocol: "tcp", Name: "http"}, + }, + }, + }, + Name: "hdls1-slice1", + Namespace: "testns", + Index: object.EndpointsKey("hdls1", "testns"), + }, + { + Subsets: []object.EndpointSubset{ + { + Addresses: []object.EndpointAddress{ + {IP: "172.0.0.2"}, + }, + Ports: []object.EndpointPort{ + {Port: 80, Protocol: "tcp", Name: "http"}, + }, + }, + }, + Name: "hdls1-slice2", + Namespace: "testns", + Index: object.EndpointsKey("hdls1", "testns"), + }, + { + Subsets: []object.EndpointSubset{ + { + Addresses: []object.EndpointAddress{ + {IP: "10.9.8.7", NodeName: "test.node.foo.bar"}, + }, + }, + }, + }, + } + return eps +} + +func (APIConnServiceTest) GetNodeByName(ctx context.Context, name string) (*api.Node, error) { + return &api.Node{ + ObjectMeta: meta.ObjectMeta{ + Name: "test.node.foo.bar", + }, + }, nil +} + +func (APIConnServiceTest) GetNamespaceByName(name string) (*object.Namespace, error) { + return &object.Namespace{ + Name: name, + }, nil +} + +func TestServices(t *testing.T) { + k := New([]string{"interwebs.test."}) + k.APIConn = &APIConnServiceTest{} + + type svcAns struct { + host string + key string + } + type svcTest struct { + qname string + qtype uint16 + answer []svcAns + } + tests := []svcTest{ + // Cluster IP Services + {qname: "svc1.testns.svc.interwebs.test.", qtype: dns.TypeA, answer: []svcAns{{host: "10.0.0.1", key: "/" + coredns + "/test/interwebs/svc/testns/svc1"}}}, + {qname: 
"_http._tcp.svc1.testns.svc.interwebs.test.", qtype: dns.TypeSRV, answer: []svcAns{{host: "10.0.0.1", key: "/" + coredns + "/test/interwebs/svc/testns/svc1"}}}, + {qname: "ep1a.svc1.testns.svc.interwebs.test.", qtype: dns.TypeA, answer: []svcAns{{host: "172.0.0.1", key: "/" + coredns + "/test/interwebs/svc/testns/svc1/ep1a"}}}, + + // Dual-Stack Cluster IP Service + { + qname: "_http._tcp.svc-dual-stack.testns.svc.interwebs.test.", + qtype: dns.TypeSRV, + answer: []svcAns{ + {host: "10.0.0.2", key: "/" + coredns + "/test/interwebs/svc/testns/svc-dual-stack"}, + {host: "10::2", key: "/" + coredns + "/test/interwebs/svc/testns/svc-dual-stack"}, + }, + }, + + // External Services + {qname: "external.testns.svc.interwebs.test.", qtype: dns.TypeCNAME, answer: []svcAns{{host: "coredns.io", key: "/" + coredns + "/test/interwebs/svc/testns/external"}}}, + + // Headless Services + {qname: "hdls1.testns.svc.interwebs.test.", qtype: dns.TypeA, answer: []svcAns{{host: "172.0.0.2", key: "/" + coredns + "/test/interwebs/svc/testns/hdls1/172-0-0-2"}}}, + } + + for i, test := range tests { + state := request.Request{ + Req: &dns.Msg{Question: []dns.Question{{Name: test.qname, Qtype: test.qtype}}}, + Zone: "interwebs.test.", // must match from k.Zones[0] + } + svcs, e := k.Services(context.TODO(), state, false, plugin.Options{}) + if e != nil { + t.Errorf("Test %d: got error '%v'", i, e) + continue + } + if len(svcs) != len(test.answer) { + t.Errorf("Test %d, expected %v answer, got %v", i, len(test.answer), len(svcs)) + continue + } + + for j := range svcs { + if test.answer[j].host != svcs[j].Host { + t.Errorf("Test %d, expected host '%v', got '%v'", i, test.answer[j].host, svcs[j].Host) + } + if test.answer[j].key != svcs[j].Key { + t.Errorf("Test %d, expected key '%v', got '%v'", i, test.answer[j].key, svcs[j].Key) + } + } + } +} + +func TestServicesAuthority(t *testing.T) { + k := New([]string{"interwebs.test."}) + k.APIConn = &APIConnServiceTest{} + + type svcAns struct { + host string + key string + } + type svcTest struct { + localIPs []net.IP + qname string + qtype uint16 + answer []svcAns + } + tests := []svcTest{ + {localIPs: []net.IP{net.ParseIP("1.2.3.4")}, qname: "ns.dns.interwebs.test.", qtype: dns.TypeA, answer: []svcAns{{host: "1.2.3.4", key: "/" + coredns + "/test/interwebs/dns/ns"}}}, + {localIPs: []net.IP{net.ParseIP("1.2.3.4")}, qname: "ns.dns.interwebs.test.", qtype: dns.TypeAAAA}, + {localIPs: []net.IP{net.ParseIP("1:2::3:4")}, qname: "ns.dns.interwebs.test.", qtype: dns.TypeA}, + {localIPs: []net.IP{net.ParseIP("1:2::3:4")}, qname: "ns.dns.interwebs.test.", qtype: dns.TypeAAAA, answer: []svcAns{{host: "1:2::3:4", key: "/" + coredns + "/test/interwebs/dns/ns"}}}, + { + localIPs: []net.IP{net.ParseIP("1.2.3.4"), net.ParseIP("1:2::3:4")}, + qname: "ns.dns.interwebs.test.", + qtype: dns.TypeNS, answer: []svcAns{ + {host: "1.2.3.4", key: "/" + coredns + "/test/interwebs/dns/ns"}, + {host: "1:2::3:4", key: "/" + coredns + "/test/interwebs/dns/ns"}, + }, + }, + } + + for i, test := range tests { + k.localIPs = test.localIPs + + state := request.Request{ + Req: &dns.Msg{Question: []dns.Question{{Name: test.qname, Qtype: test.qtype}}}, + Zone: k.Zones[0], + } + svcs, e := k.Services(context.TODO(), state, false, plugin.Options{}) + if e != nil { + t.Errorf("Test %d: got error '%v'", i, e) + continue + } + if test.answer != nil && len(svcs) != len(test.answer) { + t.Errorf("Test %d, expected 1 answer, got %v", i, len(svcs)) + continue + } + if test.answer == nil && len(svcs) != 0 { + 
t.Errorf("Test %d, expected no answer, got %v", i, len(svcs)) + continue + } + + if test.answer == nil && len(svcs) == 0 { + continue + } + + for i, answer := range test.answer { + if answer.host != svcs[i].Host { + t.Errorf("Test %d, expected host '%v', got '%v'", i, answer.host, svcs[i].Host) + } + if answer.key != svcs[i].Key { + t.Errorf("Test %d, expected key '%v', got '%v'", i, answer.key, svcs[i].Key) + } + } + } +} diff --git a/plugin/kubernetes/local.go b/plugin/kubernetes/local.go new file mode 100644 index 0000000..a754f21 --- /dev/null +++ b/plugin/kubernetes/local.go @@ -0,0 +1,37 @@ +package kubernetes + +import ( + "net" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" +) + +// boundIPs returns the list of non-loopback IPs that CoreDNS is bound to +func boundIPs(c *caddy.Controller) (ips []net.IP) { + conf := dnsserver.GetConfig(c) + hosts := conf.ListenHosts + if hosts == nil || hosts[0] == "" { + hosts = nil + addrs, err := net.InterfaceAddrs() + if err != nil { + return nil + } + for _, addr := range addrs { + hosts = append(hosts, addr.String()) + } + } + for _, host := range hosts { + ip, _, _ := net.ParseCIDR(host) + ip4 := ip.To4() + if ip4 != nil && !ip4.IsLoopback() { + ips = append(ips, ip4) + continue + } + ip6 := ip.To16() + if ip6 != nil && !ip6.IsLoopback() { + ips = append(ips, ip6) + } + } + return ips +} diff --git a/plugin/kubernetes/log_test.go b/plugin/kubernetes/log_test.go new file mode 100644 index 0000000..b8b7b74 --- /dev/null +++ b/plugin/kubernetes/log_test.go @@ -0,0 +1,5 @@ +package kubernetes + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/kubernetes/logger.go b/plugin/kubernetes/logger.go new file mode 100644 index 0000000..ac9fe80 --- /dev/null +++ b/plugin/kubernetes/logger.go @@ -0,0 +1,38 @@ +package kubernetes + +import ( + clog "github.com/coredns/coredns/plugin/pkg/log" + + "github.com/go-logr/logr" +) + +// loggerAdapter is a simple wrapper around CoreDNS plugin logger made to implement logr.LogSink interface, which is used +// as part of klog library for logging in Kubernetes client. By using this adapter CoreDNS is able to log messages/errors from +// kubernetes client in a CoreDNS logging format +type loggerAdapter struct { + clog.P +} + +func (l *loggerAdapter) Init(_ logr.RuntimeInfo) { +} + +func (l *loggerAdapter) Enabled(_ int) bool { + // verbosity is controlled inside klog library, we do not need to do anything here + return true +} + +func (l *loggerAdapter) Info(_ int, msg string, _ ...interface{}) { + l.P.Info(msg) +} + +func (l *loggerAdapter) Error(_ error, msg string, _ ...interface{}) { + l.P.Error(msg) +} + +func (l *loggerAdapter) WithValues(_ ...interface{}) logr.LogSink { + return l +} + +func (l *loggerAdapter) WithName(_ string) logr.LogSink { + return l +} diff --git a/plugin/kubernetes/metadata.go b/plugin/kubernetes/metadata.go new file mode 100644 index 0000000..36e2f9a --- /dev/null +++ b/plugin/kubernetes/metadata.go @@ -0,0 +1,62 @@ +package kubernetes + +import ( + "context" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/metadata" + "github.com/coredns/coredns/request" +) + +// Metadata implements the metadata.Provider interface. 
+func (k *Kubernetes) Metadata(ctx context.Context, state request.Request) context.Context { + pod := k.podWithIP(state.IP()) + if pod != nil { + metadata.SetValueFunc(ctx, "kubernetes/client-namespace", func() string { + return pod.Namespace + }) + + metadata.SetValueFunc(ctx, "kubernetes/client-pod-name", func() string { + return pod.Name + }) + } + + zone := plugin.Zones(k.Zones).Matches(state.Name()) + if zone == "" { + return ctx + } + // possible optimization: cache r so it doesn't need to be calculated again in ServeDNS + r, err := parseRequest(state.Name(), zone) + if err != nil { + metadata.SetValueFunc(ctx, "kubernetes/parse-error", func() string { + return err.Error() + }) + return ctx + } + + metadata.SetValueFunc(ctx, "kubernetes/port-name", func() string { + return r.port + }) + + metadata.SetValueFunc(ctx, "kubernetes/protocol", func() string { + return r.protocol + }) + + metadata.SetValueFunc(ctx, "kubernetes/endpoint", func() string { + return r.endpoint + }) + + metadata.SetValueFunc(ctx, "kubernetes/service", func() string { + return r.service + }) + + metadata.SetValueFunc(ctx, "kubernetes/namespace", func() string { + return r.namespace + }) + + metadata.SetValueFunc(ctx, "kubernetes/kind", func() string { + return r.podOrSvc + }) + + return ctx +} diff --git a/plugin/kubernetes/metadata_test.go b/plugin/kubernetes/metadata_test.go new file mode 100644 index 0000000..009c533 --- /dev/null +++ b/plugin/kubernetes/metadata_test.go @@ -0,0 +1,155 @@ +package kubernetes + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin/metadata" + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +var metadataCases = []struct { + Qname string + Qtype uint16 + RemoteIP string + Md map[string]string +}{ + { + Qname: "foo.bar.notapod.cluster.local.", Qtype: dns.TypeA, + Md: map[string]string{ + "kubernetes/parse-error": "invalid query name", + }, + }, + { + Qname: "10-240-0-1.podns.pod.cluster.local.", Qtype: dns.TypeA, + Md: map[string]string{ + "kubernetes/endpoint": "", + "kubernetes/kind": "pod", + "kubernetes/namespace": "podns", + "kubernetes/port-name": "", + "kubernetes/protocol": "", + "kubernetes/service": "10-240-0-1", + }, + }, + { + Qname: "s.ns.svc.cluster.local.", Qtype: dns.TypeA, + Md: map[string]string{ + "kubernetes/endpoint": "", + "kubernetes/kind": "svc", + "kubernetes/namespace": "ns", + "kubernetes/port-name": "", + "kubernetes/protocol": "", + "kubernetes/service": "s", + }, + }, + { + Qname: "s.ns.svc.cluster.local.", Qtype: dns.TypeA, + RemoteIP: "10.10.10.10", + Md: map[string]string{ + "kubernetes/endpoint": "", + "kubernetes/kind": "svc", + "kubernetes/namespace": "ns", + "kubernetes/port-name": "", + "kubernetes/protocol": "", + "kubernetes/service": "s", + }, + }, + { + Qname: "_http._tcp.s.ns.svc.cluster.local.", Qtype: dns.TypeSRV, + RemoteIP: "10.10.10.10", + Md: map[string]string{ + "kubernetes/endpoint": "", + "kubernetes/kind": "svc", + "kubernetes/namespace": "ns", + "kubernetes/port-name": "http", + "kubernetes/protocol": "tcp", + "kubernetes/service": "s", + }, + }, + { + Qname: "ep.s.ns.svc.cluster.local.", Qtype: dns.TypeA, + RemoteIP: "10.10.10.10", + Md: map[string]string{ + "kubernetes/endpoint": "ep", + "kubernetes/kind": "svc", + "kubernetes/namespace": "ns", + "kubernetes/port-name": "", + "kubernetes/protocol": "", + "kubernetes/service": "s", + }, + }, + { + Qname: "example.com.", Qtype: dns.TypeA, + RemoteIP: "10.10.10.10", + Md: map[string]string{}, + }, 
+} + +func mapsDiffer(a, b map[string]string) bool { + if len(a) != len(b) { + return true + } + + for k, va := range a { + vb, ok := b[k] + if !ok || va != vb { + return true + } + } + return false +} + +func TestMetadata(t *testing.T) { + k := New([]string{"cluster.local."}) + k.APIConn = &APIConnServeTest{} + + for i, tc := range metadataCases { + ctx := metadata.ContextWithMetadata(context.Background()) + state := request.Request{ + Req: &dns.Msg{Question: []dns.Question{{Name: tc.Qname, Qtype: tc.Qtype}}}, + Zone: ".", + W: &test.ResponseWriter{RemoteIP: tc.RemoteIP}, + } + + k.Metadata(ctx, state) + + md := make(map[string]string) + for _, l := range metadata.Labels(ctx) { + md[l] = metadata.ValueFunc(ctx, l)() + } + if mapsDiffer(tc.Md, md) { + t.Errorf("Case %d expected metadata %v and got %v", i, tc.Md, md) + } + } +} + +func TestMetadataPodsVerified(t *testing.T) { + k := New([]string{"cluster.local."}) + k.podMode = podModeVerified + k.APIConn = &APIConnServeTest{} + + ctx := metadata.ContextWithMetadata(context.Background()) + state := request.Request{ + Req: &dns.Msg{Question: []dns.Question{{Name: "example.com.", Qtype: dns.TypeA}}}, + Zone: ".", + W: &test.ResponseWriter{}, + } + + k.Metadata(ctx, state) + + expect := map[string]string{ + "kubernetes/client-namespace": "podns", + "kubernetes/client-pod-name": "foo", + } + + md := make(map[string]string) + for _, l := range metadata.Labels(ctx) { + md[l] = metadata.ValueFunc(ctx, l)() + } + if mapsDiffer(expect, md) { + t.Errorf("Expected metadata %v and got %v", expect, md) + } +} diff --git a/plugin/kubernetes/metrics.go b/plugin/kubernetes/metrics.go new file mode 100644 index 0000000..d6927cd --- /dev/null +++ b/plugin/kubernetes/metrics.go @@ -0,0 +1,74 @@ +package kubernetes + +import ( + "context" + "net/url" + "time" + + "github.com/coredns/coredns/plugin" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "k8s.io/client-go/tools/metrics" +) + +var ( + // requestLatency measures K8s rest client requests latency grouped by verb and host. + requestLatency = promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: plugin.Namespace, + Subsystem: "kubernetes", + Name: "rest_client_request_duration_seconds", + Help: "Request latency in seconds. Broken down by verb and host.", + Buckets: prometheus.DefBuckets, + }, + []string{"verb", "host"}, + ) + + // rateLimiterLatency measures K8s rest client rate limiter latency grouped by verb and host. + rateLimiterLatency = promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: plugin.Namespace, + Subsystem: "kubernetes", + Name: "rest_client_rate_limiter_duration_seconds", + Help: "Client side rate limiter latency in seconds. Broken down by verb and host.", + Buckets: prometheus.DefBuckets, + }, + []string{"verb", "host"}, + ) + + // requestResult measures K8s rest client request metrics grouped by status code, method & host. 
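+	// All three collectors are hooked into client-go's metrics package via metrics.Register in init() below, so API requests made by the informers are observed here.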
+ requestResult = promauto.NewCounterVec( + prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "kubernetes", + Name: "rest_client_requests_total", + Help: "Number of HTTP requests, partitioned by status code, method, and host.", + }, + []string{"code", "method", "host"}, + ) +) + +func init() { + metrics.Register(metrics.RegisterOpts{ + RequestLatency: &latencyAdapter{m: requestLatency}, + RateLimiterLatency: &latencyAdapter{m: rateLimiterLatency}, + RequestResult: &resultAdapter{requestResult}, + }) +} + +type latencyAdapter struct { + m *prometheus.HistogramVec +} + +func (l *latencyAdapter) Observe(_ context.Context, verb string, u url.URL, latency time.Duration) { + l.m.WithLabelValues(verb, u.Host).Observe(latency.Seconds()) +} + +type resultAdapter struct { + m *prometheus.CounterVec +} + +func (r *resultAdapter) Increment(_ context.Context, code, method, host string) { + r.m.WithLabelValues(code, method, host).Inc() +} diff --git a/plugin/kubernetes/metrics_test.backup b/plugin/kubernetes/metrics_test.backup new file mode 100644 index 0000000..8274eef --- /dev/null +++ b/plugin/kubernetes/metrics_test.backup @@ -0,0 +1,203 @@ +package kubernetes + +import ( + "strings" + "testing" + "time" + + "github.com/coredns/coredns/plugin/kubernetes/object" + "github.com/prometheus/client_golang/prometheus/testutil" + api "k8s.io/api/core/v1" + discovery "k8s.io/api/discovery/v1beta1" + meta "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/tools/cache" +) + +const ( + namespace = "testns" +) + +var expected = ` + # HELP coredns_kubernetes_dns_programming_duration_seconds Histogram of the time (in seconds) it took to program a dns instance. + # TYPE coredns_kubernetes_dns_programming_duration_seconds histogram + coredns_kubernetes_dns_programming_duration_seconds_bucket{service_kind="headless_with_selector",le="0.001"} 0 + coredns_kubernetes_dns_programming_duration_seconds_bucket{service_kind="headless_with_selector",le="0.002"} 0 + coredns_kubernetes_dns_programming_duration_seconds_bucket{service_kind="headless_with_selector",le="0.004"} 0 + coredns_kubernetes_dns_programming_duration_seconds_bucket{service_kind="headless_with_selector",le="0.008"} 0 + coredns_kubernetes_dns_programming_duration_seconds_bucket{service_kind="headless_with_selector",le="0.016"} 0 + coredns_kubernetes_dns_programming_duration_seconds_bucket{service_kind="headless_with_selector",le="0.032"} 0 + coredns_kubernetes_dns_programming_duration_seconds_bucket{service_kind="headless_with_selector",le="0.064"} 0 + coredns_kubernetes_dns_programming_duration_seconds_bucket{service_kind="headless_with_selector",le="0.128"} 0 + coredns_kubernetes_dns_programming_duration_seconds_bucket{service_kind="headless_with_selector",le="0.256"} 0 + coredns_kubernetes_dns_programming_duration_seconds_bucket{service_kind="headless_with_selector",le="0.512"} 0 + coredns_kubernetes_dns_programming_duration_seconds_bucket{service_kind="headless_with_selector",le="1.024"} 1 + coredns_kubernetes_dns_programming_duration_seconds_bucket{service_kind="headless_with_selector",le="2.048"} 2 + coredns_kubernetes_dns_programming_duration_seconds_bucket{service_kind="headless_with_selector",le="4.096"} 2 + coredns_kubernetes_dns_programming_duration_seconds_bucket{service_kind="headless_with_selector",le="8.192"} 2 + coredns_kubernetes_dns_programming_duration_seconds_bucket{service_kind="headless_with_selector",le="16.384"} 2 + 
coredns_kubernetes_dns_programming_duration_seconds_bucket{service_kind="headless_with_selector",le="32.768"} 2 + coredns_kubernetes_dns_programming_duration_seconds_bucket{service_kind="headless_with_selector",le="65.536"} 2 + coredns_kubernetes_dns_programming_duration_seconds_bucket{service_kind="headless_with_selector",le="131.072"} 2 + coredns_kubernetes_dns_programming_duration_seconds_bucket{service_kind="headless_with_selector",le="262.144"} 2 + coredns_kubernetes_dns_programming_duration_seconds_bucket{service_kind="headless_with_selector",le="524.288"} 2 + coredns_kubernetes_dns_programming_duration_seconds_bucket{service_kind="headless_with_selector",le="+Inf"} 2 + coredns_kubernetes_dns_programming_duration_seconds_sum{service_kind="headless_with_selector"} 3 + coredns_kubernetes_dns_programming_duration_seconds_count{service_kind="headless_with_selector"} 2 + ` + +func TestDNSProgrammingLatencyEndpointSlices(t *testing.T) { + now := time.Now() + + svcIdx := cache.NewIndexer(cache.DeletionHandlingMetaNamespaceKeyFunc, cache.Indexers{svcNameNamespaceIndex: svcNameNamespaceIndexFunc}) + epIdx := cache.NewIndexer(cache.DeletionHandlingMetaNamespaceKeyFunc, cache.Indexers{}) + + dns := dnsControl{svcLister: svcIdx} + svcProc := object.DefaultProcessor(object.ToService, nil)(svcIdx, cache.ResourceEventHandlerFuncs{}) + epProc := object.DefaultProcessor(object.EndpointSliceToEndpoints, dns.EndpointSliceLatencyRecorder())(epIdx, cache.ResourceEventHandlerFuncs{}) + + object.DurationSinceFunc = func(t time.Time) time.Duration { + return now.Sub(t) + } + object.DNSProgrammingLatency.Reset() + + endpoints1 := []discovery.Endpoint{{ + Addresses: []string{"1.2.3.4"}, + }} + + endpoints2 := []discovery.Endpoint{{ + Addresses: []string{"1.2.3.45"}, + }} + + createService(t, svcProc, "my-service", api.ClusterIPNone) + createEndpointSlice(t, epProc, "my-service", now.Add(-2*time.Second), endpoints1) + updateEndpointSlice(t, epProc, "my-service", now.Add(-1*time.Second), endpoints2) + + createEndpointSlice(t, epProc, "endpoints-no-service", now.Add(-4*time.Second), nil) + + createService(t, svcProc, "clusterIP-service", "10.40.0.12") + createEndpointSlice(t, epProc, "clusterIP-service", now.Add(-8*time.Second), nil) + + createService(t, svcProc, "headless-no-annotation", api.ClusterIPNone) + createEndpointSlice(t, epProc, "headless-no-annotation", nil, nil) + + createService(t, svcProc, "headless-wrong-annotation", api.ClusterIPNone) + createEndpointSlice(t, epProc, "headless-wrong-annotation", "wrong-value", nil) + + if err := testutil.CollectAndCompare(object.DNSProgrammingLatency, strings.NewReader(expected)); err != nil { + t.Error(err) + } +} + +func TestDnsProgrammingLatencyEndpoints(t *testing.T) { + now := time.Now() + + svcIdx := cache.NewIndexer(cache.DeletionHandlingMetaNamespaceKeyFunc, cache.Indexers{svcNameNamespaceIndex: svcNameNamespaceIndexFunc}) + epIdx := cache.NewIndexer(cache.DeletionHandlingMetaNamespaceKeyFunc, cache.Indexers{}) + + dns := dnsControl{svcLister: svcIdx} + svcProc := object.DefaultProcessor(object.ToService, nil)(svcIdx, cache.ResourceEventHandlerFuncs{}) + epProc := object.DefaultProcessor(object.ToEndpoints, dns.EndpointsLatencyRecorder())(epIdx, cache.ResourceEventHandlerFuncs{}) + + object.DurationSinceFunc = func(t time.Time) time.Duration { + return now.Sub(t) + } + object.DNSProgrammingLatency.Reset() + + subset1 := []api.EndpointSubset{{ + Addresses: []api.EndpointAddress{{IP: "1.2.3.4", Hostname: "foo"}}, + }} + + subset2 := []api.EndpointSubset{{ 
+ Addresses: []api.EndpointAddress{{IP: "1.2.3.5", Hostname: "foo"}}, + }} + + createService(t, svcProc, "my-service", api.ClusterIPNone) + createEndpoints(t, epProc, "my-service", now.Add(-2*time.Second), subset1) + updateEndpoints(t, epProc, "my-service", now.Add(-1*time.Second), subset2) + + createEndpoints(t, epProc, "endpoints-no-service", now.Add(-4*time.Second), nil) + + createService(t, svcProc, "clusterIP-service", "10.40.0.12") + createEndpoints(t, epProc, "clusterIP-service", now.Add(-8*time.Second), nil) + + createService(t, svcProc, "headless-no-annotation", api.ClusterIPNone) + createEndpoints(t, epProc, "headless-no-annotation", nil, nil) + + createService(t, svcProc, "headless-wrong-annotation", api.ClusterIPNone) + createEndpoints(t, epProc, "headless-wrong-annotation", "wrong-value", nil) + + if err := testutil.CollectAndCompare(object.DNSProgrammingLatency, strings.NewReader(expected)); err != nil { + t.Error(err) + } +} + +func buildEndpoints(name string, lastChangeTriggerTime interface{}, subsets []api.EndpointSubset) *api.Endpoints { + annotations := make(map[string]string) + switch v := lastChangeTriggerTime.(type) { + case string: + annotations[api.EndpointsLastChangeTriggerTime] = v + case time.Time: + annotations[api.EndpointsLastChangeTriggerTime] = v.Format(time.RFC3339Nano) + } + return &api.Endpoints{ + ObjectMeta: meta.ObjectMeta{Namespace: namespace, Name: name, Annotations: annotations}, + Subsets: subsets, + } +} + +func buildEndpointSlice(name string, lastChangeTriggerTime interface{}, endpoints []discovery.Endpoint) *discovery.EndpointSlice { + annotations := make(map[string]string) + switch v := lastChangeTriggerTime.(type) { + case string: + annotations[api.EndpointsLastChangeTriggerTime] = v + case time.Time: + annotations[api.EndpointsLastChangeTriggerTime] = v.Format(time.RFC3339Nano) + } + return &discovery.EndpointSlice{ + ObjectMeta: meta.ObjectMeta{ + Namespace: namespace, Name: name + "-12345", + Labels: map[string]string{discovery.LabelServiceName: name}, + Annotations: annotations, + }, + Endpoints: endpoints, + } +} + +func createEndpoints(t *testing.T, processor cache.ProcessFunc, name string, triggerTime interface{}, subsets []api.EndpointSubset) { + err := processor(cache.Deltas{{Type: cache.Added, Object: buildEndpoints(name, triggerTime, subsets)}}) + if err != nil { + t.Fatal(err) + } +} + +func updateEndpoints(t *testing.T, processor cache.ProcessFunc, name string, triggerTime interface{}, subsets []api.EndpointSubset) { + err := processor(cache.Deltas{{Type: cache.Updated, Object: buildEndpoints(name, triggerTime, subsets)}}) + if err != nil { + t.Fatal(err) + } +} + +func createEndpointSlice(t *testing.T, processor cache.ProcessFunc, name string, triggerTime interface{}, endpoints []discovery.Endpoint) { + err := processor(cache.Deltas{{Type: cache.Added, Object: buildEndpointSlice(name, triggerTime, endpoints)}}) + if err != nil { + t.Fatal(err) + } +} + +func updateEndpointSlice(t *testing.T, processor cache.ProcessFunc, name string, triggerTime interface{}, endpoints []discovery.Endpoint) { + err := processor(cache.Deltas{{Type: cache.Updated, Object: buildEndpointSlice(name, triggerTime, endpoints)}}) + if err != nil { + t.Fatal(err) + } +} + +func createService(t *testing.T, processor cache.ProcessFunc, name string, clusterIp string) { + obj := &api.Service{ + ObjectMeta: meta.ObjectMeta{Namespace: namespace, Name: name}, + Spec: api.ServiceSpec{ClusterIP: clusterIp}, + } + err := processor(cache.Deltas{{Type: cache.Added, 
Object: obj}}) + if err != nil { + t.Fatal(err) + } +} diff --git a/plugin/kubernetes/namespace.go b/plugin/kubernetes/namespace.go new file mode 100644 index 0000000..3e90bab --- /dev/null +++ b/plugin/kubernetes/namespace.go @@ -0,0 +1,24 @@ +package kubernetes + +// filteredNamespaceExists checks if namespace exists in this cluster +// according to any `namespace_labels` plugin configuration specified. +// Returns true even for namespaces not exposed by plugin configuration, +// see namespaceExposed. +func (k *Kubernetes) filteredNamespaceExists(namespace string) bool { + _, err := k.APIConn.GetNamespaceByName(namespace) + return err == nil +} + +// configuredNamespace returns true when the namespace is exposed through the plugin +// `namespaces` configuration. +func (k *Kubernetes) configuredNamespace(namespace string) bool { + _, ok := k.Namespaces[namespace] + if len(k.Namespaces) > 0 && !ok { + return false + } + return true +} + +func (k *Kubernetes) namespaceExposed(namespace string) bool { + return k.configuredNamespace(namespace) && k.filteredNamespaceExists(namespace) +} diff --git a/plugin/kubernetes/namespace_test.go b/plugin/kubernetes/namespace_test.go new file mode 100644 index 0000000..c302b42 --- /dev/null +++ b/plugin/kubernetes/namespace_test.go @@ -0,0 +1,72 @@ +package kubernetes + +import ( + "testing" +) + +func TestFilteredNamespaceExists(t *testing.T) { + tests := []struct { + expected bool + kubernetesNamespaces map[string]struct{} + testNamespace string + }{ + {true, map[string]struct{}{}, "foobar"}, + {false, map[string]struct{}{}, "nsnoexist"}, + } + + k := Kubernetes{} + k.APIConn = &APIConnServeTest{} + for i, test := range tests { + k.Namespaces = test.kubernetesNamespaces + actual := k.filteredNamespaceExists(test.testNamespace) + if actual != test.expected { + t.Errorf("Test %d failed. Filtered namespace %s was expected to exist", i, test.testNamespace) + } + } +} + +func TestNamespaceExposed(t *testing.T) { + tests := []struct { + expected bool + kubernetesNamespaces map[string]struct{} + testNamespace string + }{ + {true, map[string]struct{}{"foobar": {}}, "foobar"}, + {false, map[string]struct{}{"foobar": {}}, "nsnoexist"}, + {true, map[string]struct{}{}, "foobar"}, + {true, map[string]struct{}{}, "nsnoexist"}, + } + + k := Kubernetes{} + k.APIConn = &APIConnServeTest{} + for i, test := range tests { + k.Namespaces = test.kubernetesNamespaces + actual := k.configuredNamespace(test.testNamespace) + if actual != test.expected { + t.Errorf("Test %d failed. Namespace %s was expected to be exposed", i, test.testNamespace) + } + } +} + +func TestNamespaceValid(t *testing.T) { + tests := []struct { + expected bool + kubernetesNamespaces map[string]struct{} + testNamespace string + }{ + {true, map[string]struct{}{"foobar": {}}, "foobar"}, + {false, map[string]struct{}{"foobar": {}}, "nsnoexist"}, + {true, map[string]struct{}{}, "foobar"}, + {false, map[string]struct{}{}, "nsnoexist"}, + } + + k := Kubernetes{} + k.APIConn = &APIConnServeTest{} + for i, test := range tests { + k.Namespaces = test.kubernetesNamespaces + actual := k.namespaceExposed(test.testNamespace) + if actual != test.expected { + t.Errorf("Test %d failed. 
Namespace %s was expected to be valid", i, test.testNamespace) + } + } +} diff --git a/plugin/kubernetes/ns.go b/plugin/kubernetes/ns.go new file mode 100644 index 0000000..eb40c34 --- /dev/null +++ b/plugin/kubernetes/ns.go @@ -0,0 +1,103 @@ +package kubernetes + +import ( + "net" + "strings" + + "github.com/miekg/dns" +) + +func isDefaultNS(name, zone string) bool { + return strings.Index(name, defaultNSName) == 0 && strings.Index(name, zone) == len(defaultNSName) +} + +// nsAddrs returns the A or AAAA records for the CoreDNS service in the cluster. If the service cannot be found, +// it returns a record for the local address of the machine we're running on. +func (k *Kubernetes) nsAddrs(external, headless bool, zone string) []dns.RR { + var ( + svcNames []string + svcIPs []net.IP + foundEndpoint bool + ) + + // Find the CoreDNS Endpoints + for _, localIP := range k.localIPs { + endpoints := k.APIConn.EpIndexReverse(localIP.String()) + + // Collect IPs for all Services of the Endpoints + for _, endpoint := range endpoints { + foundEndpoint = true + svcs := k.APIConn.SvcIndex(endpoint.Index) + for _, svc := range svcs { + if external { + svcName := strings.Join([]string{svc.Name, svc.Namespace, zone}, ".") + + if headless && svc.Headless() { + for _, s := range endpoint.Subsets { + for _, a := range s.Addresses { + svcNames = append(svcNames, endpointHostname(a, k.endpointNameMode)+"."+svcName) + svcIPs = append(svcIPs, net.ParseIP(a.IP)) + } + } + } else { + for _, exIP := range svc.ExternalIPs { + svcNames = append(svcNames, svcName) + svcIPs = append(svcIPs, net.ParseIP(exIP)) + } + } + + continue + } + svcName := strings.Join([]string{svc.Name, svc.Namespace, Svc, zone}, ".") + if svc.Headless() { + // For a headless service, use the endpoints IPs + for _, s := range endpoint.Subsets { + for _, a := range s.Addresses { + svcNames = append(svcNames, endpointHostname(a, k.endpointNameMode)+"."+svcName) + svcIPs = append(svcIPs, net.ParseIP(a.IP)) + } + } + } else { + for _, clusterIP := range svc.ClusterIPs { + svcNames = append(svcNames, svcName) + svcIPs = append(svcIPs, net.ParseIP(clusterIP)) + } + } + } + } + } + + // If no CoreDNS endpoints were found, use the localIPs directly + if !foundEndpoint { + svcIPs = make([]net.IP, len(k.localIPs)) + svcNames = make([]string, len(k.localIPs)) + for i, localIP := range k.localIPs { + svcNames[i] = defaultNSName + zone + svcIPs[i] = localIP + } + } + + // Create an RR slice of collected IPs + rrs := make([]dns.RR, len(svcIPs)) + for i, ip := range svcIPs { + if ip.To4() == nil { + rr := new(dns.AAAA) + rr.Hdr.Class = dns.ClassINET + rr.Hdr.Rrtype = dns.TypeAAAA + rr.Hdr.Name = svcNames[i] + rr.AAAA = ip + rrs[i] = rr + continue + } + rr := new(dns.A) + rr.Hdr.Class = dns.ClassINET + rr.Hdr.Rrtype = dns.TypeA + rr.Hdr.Name = svcNames[i] + rr.A = ip + rrs[i] = rr + } + + return rrs +} + +const defaultNSName = "ns.dns." 
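A note on the name check above: `isDefaultNS` is easy to misread. It requires the literal `ns.dns.` prefix and the zone to begin immediately after that prefix. The following is a minimal standalone sketch of that behavior only; the helper is re-declared here purely for illustration (the real function is unexported inside the plugin), so the `main` wrapper and the sample names are assumptions, not part of the plugin.

```go
package main

import (
	"fmt"
	"strings"
)

// isDefaultNS mirrors the check in ns.go: the query name must start with the
// literal "ns.dns." prefix, and the zone must follow directly after it.
func isDefaultNS(name, zone string) bool {
	const defaultNSName = "ns.dns."
	return strings.Index(name, defaultNSName) == 0 && strings.Index(name, zone) == len(defaultNSName)
}

func main() {
	zone := "cluster.local."
	fmt.Println(isDefaultNS("ns.dns.cluster.local.", zone)) // true
	fmt.Println(isDefaultNS("ns.dns.other.local.", zone))   // false: zone does not follow the prefix
	fmt.Println(isDefaultNS("foo.cluster.local.", zone))    // false: missing ns.dns. prefix
}
```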
diff --git a/plugin/kubernetes/ns_test.go b/plugin/kubernetes/ns_test.go new file mode 100644 index 0000000..bdf326e --- /dev/null +++ b/plugin/kubernetes/ns_test.go @@ -0,0 +1,219 @@ +package kubernetes + +import ( + "context" + "fmt" + "net" + "testing" + + "github.com/coredns/coredns/plugin/kubernetes/object" + + "github.com/miekg/dns" + api "k8s.io/api/core/v1" +) + +type APIConnTest struct{} + +func (APIConnTest) HasSynced() bool { return true } +func (APIConnTest) Run() {} +func (APIConnTest) Stop() error { return nil } +func (APIConnTest) PodIndex(string) []*object.Pod { return nil } +func (APIConnTest) SvcIndexReverse(string) []*object.Service { return nil } +func (APIConnTest) SvcExtIndexReverse(string) []*object.Service { return nil } +func (APIConnTest) EpIndex(string) []*object.Endpoints { return nil } +func (APIConnTest) EndpointsList() []*object.Endpoints { return nil } +func (APIConnTest) Modified(bool) int64 { return 0 } + +func (a APIConnTest) SvcIndex(s string) []*object.Service { + switch s { + case "dns-service.kube-system": + return []*object.Service{a.ServiceList()[0]} + case "hdls-dns-service.kube-system": + return []*object.Service{a.ServiceList()[1]} + case "dns6-service.kube-system": + return []*object.Service{a.ServiceList()[2]} + } + return nil +} + +var svcs = []*object.Service{ + { + Name: "dns-service", + Namespace: "kube-system", + ClusterIPs: []string{"10.0.0.111"}, + }, + { + Name: "hdls-dns-service", + Namespace: "kube-system", + ClusterIPs: []string{api.ClusterIPNone}, + }, + { + Name: "dns6-service", + Namespace: "kube-system", + ClusterIPs: []string{"10::111"}, + }, +} + +func (APIConnTest) ServiceList() []*object.Service { + return svcs +} + +func (APIConnTest) EpIndexReverse(ip string) []*object.Endpoints { + if ip != "10.244.0.20" { + return nil + } + eps := []*object.Endpoints{ + { + Name: "dns-service-slice1", + Namespace: "kube-system", + Index: object.EndpointsKey("dns-service", "kube-system"), + Subsets: []object.EndpointSubset{ + {Addresses: []object.EndpointAddress{{IP: "10.244.0.20"}}}, + }, + }, + { + Name: "hdls-dns-service-slice1", + Namespace: "kube-system", + Index: object.EndpointsKey("hdls-dns-service", "kube-system"), + Subsets: []object.EndpointSubset{ + {Addresses: []object.EndpointAddress{{IP: "10.244.0.20"}}}, + }, + }, + { + Name: "dns6-service-slice1", + Namespace: "kube-system", + Index: object.EndpointsKey("dns6-service", "kube-system"), + Subsets: []object.EndpointSubset{ + {Addresses: []object.EndpointAddress{{IP: "10.244.0.20"}}}, + }, + }, + } + return eps +} + +func (APIConnTest) GetNodeByName(ctx context.Context, name string) (*api.Node, error) { + return &api.Node{}, nil +} +func (APIConnTest) GetNamespaceByName(name string) (*object.Namespace, error) { + return nil, fmt.Errorf("namespace not found") +} + +func TestNsAddrs(t *testing.T) { + k := New([]string{"inter.webs.test."}) + k.APIConn = &APIConnTest{} + k.localIPs = []net.IP{net.ParseIP("10.244.0.20")} + + cdrs := k.nsAddrs(false, false, k.Zones[0]) + + if len(cdrs) != 3 { + t.Fatalf("Expected 3 results, got %v", len(cdrs)) + } + cdr := cdrs[0] + expected := "10.0.0.111" + if cdr.(*dns.A).A.String() != expected { + t.Errorf("Expected 1st A to be %q, got %q", expected, cdr.(*dns.A).A.String()) + } + expected = "dns-service.kube-system.svc.inter.webs.test." 
+ if cdr.Header().Name != expected {
+ t.Errorf("Expected 1st Header Name to be %q, got %q", expected, cdr.Header().Name)
+ }
+ cdr = cdrs[1]
+ expected = "10.244.0.20"
+ if cdr.(*dns.A).A.String() != expected {
+ t.Errorf("Expected 2nd A to be %q, got %q", expected, cdr.(*dns.A).A.String())
+ }
+ expected = "10-244-0-20.hdls-dns-service.kube-system.svc.inter.webs.test."
+ if cdr.Header().Name != expected {
+ t.Errorf("Expected 2nd Header Name to be %q, got %q", expected, cdr.Header().Name)
+ }
+ cdr = cdrs[2]
+ expected = "10::111"
+ if cdr.(*dns.AAAA).AAAA.String() != expected {
+ t.Errorf("Expected AAAA to be %q, got %q", expected, cdr.(*dns.AAAA).AAAA.String())
+ }
+ expected = "dns6-service.kube-system.svc.inter.webs.test."
+ if cdr.Header().Name != expected {
+ t.Errorf("Expected AAAA Header Name to be %q, got %q", expected, cdr.Header().Name)
+ }
+}
+
+func TestNsAddrsExternalHeadless(t *testing.T) {
+ k := New([]string{"example.com."})
+ k.APIConn = &APIConnTest{}
+ k.localIPs = []net.IP{net.ParseIP("10.244.0.20")}
+
+ // only the headless service is expected in the result
+ cdrs := k.nsAddrs(true, true, k.Zones[0])
+
+ if len(cdrs) != 1 {
+ t.Fatalf("Expected 1 result, got %v", cdrs)
+ }
+
+ cdr := cdrs[0]
+ expected := "10.244.0.20"
+ if cdr.(*dns.A).A.String() != expected {
+ t.Errorf("Expected A address to be %q, got %q", expected, cdr.(*dns.A).A.String())
+ }
+ expected = "10-244-0-20.hdls-dns-service.kube-system.example.com."
+ if cdr.Header().Name != expected {
+ t.Errorf("Expected record name to be %q, got %q", expected, cdr.Header().Name)
+ }
+}
+
+func TestNsAddrsExternal(t *testing.T) {
+ k := New([]string{"example.com."})
+ k.APIConn = &APIConnTest{}
+ k.localIPs = []net.IP{net.ParseIP("10.244.0.20")}
+
+ // initially no services have an external IP ...
+ cdrs := k.nsAddrs(true, false, k.Zones[0])
+
+ if len(cdrs) != 0 {
+ t.Fatalf("Expected 0 results, got %v", len(cdrs))
+ }
+
+ // Add an external IP to one of the services ...
+ svcs[0].ExternalIPs = []string{"1.2.3.4"}
+ cdrs = k.nsAddrs(true, false, k.Zones[0])
+
+ if len(cdrs) != 1 {
+ t.Fatalf("Expected 1 result, got %v", len(cdrs))
+ }
+ cdr := cdrs[0]
+ expected := "1.2.3.4"
+ if cdr.(*dns.A).A.String() != expected {
+ t.Errorf("Expected A address to be %q, got %q", expected, cdr.(*dns.A).A.String())
+ }
+ expected = "dns-service.kube-system.example.com."
+ if cdr.Header().Name != expected {
+ t.Errorf("Expected record name to be %q, got %q", expected, cdr.Header().Name)
+ }
+}
+
+func TestNsAddrsExternalWithPreexistingExternalIP(t *testing.T) {
+ k := New([]string{"example.com."})
+ k.APIConn = &APIConnTest{}
+ k.localIPs = []net.IP{net.ParseIP("10.244.0.20")}
+
+ // the external IP is already set before the first nsAddrs call ...
+ svcs[0].ExternalIPs = []string{"1.2.3.4"}
+
+ cdrs := k.nsAddrs(true, false, k.Zones[0])
+
+ if len(cdrs) != 1 {
+ t.Fatalf("Expected 1 result, got %v", len(cdrs))
+ }
+ cdr := cdrs[0]
+ expected := "1.2.3.4"
+ if cdr.(*dns.A).A.String() != expected {
+ t.Errorf("Expected A address to be %q, got %q", expected, cdr.(*dns.A).A.String())
+ }
+ expected = "dns-service.kube-system.example.com."
+ if cdr.Header().Name != expected { + t.Errorf("Expected record name to be %q, got %q", expected, cdr.Header().Name) + } +} diff --git a/plugin/kubernetes/object/endpoint.go b/plugin/kubernetes/object/endpoint.go new file mode 100644 index 0000000..26555e1 --- /dev/null +++ b/plugin/kubernetes/object/endpoint.go @@ -0,0 +1,182 @@ +package object + +import ( + "fmt" + + discovery "k8s.io/api/discovery/v1" + meta "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" +) + +// Endpoints is a stripped down api.Endpoints with only the items we need for CoreDNS. +type Endpoints struct { + // Don't add new fields to this struct without talking to the CoreDNS maintainers. + Version string + Name string + Namespace string + Index string + IndexIP []string + Subsets []EndpointSubset + + *Empty +} + +// EndpointSubset is a group of addresses with a common set of ports. The +// expanded set of endpoints is the Cartesian product of Addresses x Ports. +type EndpointSubset struct { + Addresses []EndpointAddress + Ports []EndpointPort +} + +// EndpointAddress is a tuple that describes single IP address. +type EndpointAddress struct { + IP string + Hostname string + NodeName string + TargetRefName string +} + +// EndpointPort is a tuple that describes a single port. +type EndpointPort struct { + Port int32 + Name string + Protocol string +} + +// EndpointsKey returns a string using for the index. +func EndpointsKey(name, namespace string) string { return name + "." + namespace } + +// EndpointSliceToEndpoints converts a *discovery.EndpointSlice to a *Endpoints. +func EndpointSliceToEndpoints(obj meta.Object) (meta.Object, error) { + ends, ok := obj.(*discovery.EndpointSlice) + if !ok { + return nil, fmt.Errorf("unexpected object %v", obj) + } + e := &Endpoints{ + Version: ends.GetResourceVersion(), + Name: ends.GetName(), + Namespace: ends.GetNamespace(), + Index: EndpointsKey(ends.Labels[discovery.LabelServiceName], ends.GetNamespace()), + Subsets: make([]EndpointSubset, 1), + } + + if len(ends.Ports) == 0 { + // Add sentinel if there are no ports. + e.Subsets[0].Ports = []EndpointPort{{Port: -1}} + } else { + e.Subsets[0].Ports = make([]EndpointPort, len(ends.Ports)) + for k, p := range ends.Ports { + port := int32(-1) + name := "" + protocol := "" + if p.Port != nil { + port = *p.Port + } + if p.Name != nil { + name = *p.Name + } + if p.Protocol != nil { + protocol = string(*p.Protocol) + } + ep := EndpointPort{Port: port, Name: name, Protocol: protocol} + e.Subsets[0].Ports[k] = ep + } + } + + for _, end := range ends.Endpoints { + if !endpointsliceReady(end.Conditions.Ready) { + continue + } + for _, a := range end.Addresses { + ea := EndpointAddress{IP: a} + if end.Hostname != nil { + ea.Hostname = *end.Hostname + } + // ignore pod names that are too long to be a valid label + if end.TargetRef != nil && len(end.TargetRef.Name) < 64 { + ea.TargetRefName = end.TargetRef.Name + } + if end.NodeName != nil { + ea.NodeName = *end.NodeName + } + e.Subsets[0].Addresses = append(e.Subsets[0].Addresses, ea) + e.IndexIP = append(e.IndexIP, a) + } + } + + *ends = discovery.EndpointSlice{} + + return e, nil +} + +func endpointsliceReady(ready *bool) bool { + // Per API docs: a nil value indicates an unknown state. In most cases consumers + // should interpret this unknown state as ready. + if ready == nil { + return true + } + return *ready +} + +// CopyWithoutSubsets copies e, without the subsets. 
+func (e *Endpoints) CopyWithoutSubsets() *Endpoints { + e1 := &Endpoints{ + Version: e.Version, + Name: e.Name, + Namespace: e.Namespace, + Index: e.Index, + IndexIP: make([]string, len(e.IndexIP)), + } + copy(e1.IndexIP, e.IndexIP) + return e1 +} + +var _ runtime.Object = &Endpoints{} + +// DeepCopyObject implements the ObjectKind interface. +func (e *Endpoints) DeepCopyObject() runtime.Object { + e1 := &Endpoints{ + Version: e.Version, + Name: e.Name, + Namespace: e.Namespace, + Index: e.Index, + IndexIP: make([]string, len(e.IndexIP)), + Subsets: make([]EndpointSubset, len(e.Subsets)), + } + copy(e1.IndexIP, e.IndexIP) + + for i, eps := range e.Subsets { + sub := EndpointSubset{ + Addresses: make([]EndpointAddress, len(eps.Addresses)), + Ports: make([]EndpointPort, len(eps.Ports)), + } + for j, a := range eps.Addresses { + ea := EndpointAddress{IP: a.IP, Hostname: a.Hostname, NodeName: a.NodeName, TargetRefName: a.TargetRefName} + sub.Addresses[j] = ea + } + for k, p := range eps.Ports { + ep := EndpointPort{Port: p.Port, Name: p.Name, Protocol: p.Protocol} + sub.Ports[k] = ep + } + e1.Subsets[i] = sub + } + return e1 +} + +// GetNamespace implements the metav1.Object interface. +func (e *Endpoints) GetNamespace() string { return e.Namespace } + +// SetNamespace implements the metav1.Object interface. +func (e *Endpoints) SetNamespace(namespace string) {} + +// GetName implements the metav1.Object interface. +func (e *Endpoints) GetName() string { return e.Name } + +// SetName implements the metav1.Object interface. +func (e *Endpoints) SetName(name string) {} + +// GetResourceVersion implements the metav1.Object interface. +func (e *Endpoints) GetResourceVersion() string { return e.Version } + +// SetResourceVersion implements the metav1.Object interface. +func (e *Endpoints) SetResourceVersion(version string) {} diff --git a/plugin/kubernetes/object/informer.go b/plugin/kubernetes/object/informer.go new file mode 100644 index 0000000..86d872c --- /dev/null +++ b/plugin/kubernetes/object/informer.go @@ -0,0 +1,88 @@ +package object + +import ( + "fmt" + + meta "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/client-go/tools/cache" +) + +// NewIndexerInformer is a copy of the cache.NewIndexerInformer function, but allows custom process function +func NewIndexerInformer(lw cache.ListerWatcher, objType runtime.Object, h cache.ResourceEventHandler, indexers cache.Indexers, builder ProcessorBuilder) (cache.Indexer, cache.Controller) { + clientState := cache.NewIndexer(cache.DeletionHandlingMetaNamespaceKeyFunc, indexers) + + cfg := &cache.Config{ + Queue: cache.NewDeltaFIFOWithOptions(cache.DeltaFIFOOptions{KeyFunction: cache.MetaNamespaceKeyFunc, KnownObjects: clientState}), + ListerWatcher: lw, + ObjectType: objType, + FullResyncPeriod: defaultResyncPeriod, + RetryOnError: false, + Process: builder(clientState, h), + } + return clientState, cache.New(cfg) +} + +// RecordLatencyFunc is a function for recording api object delta latency +type RecordLatencyFunc func(meta.Object) + +// DefaultProcessor is based on the Process function from cache.NewIndexerInformer except it does a conversion. 
+func DefaultProcessor(convert ToFunc, recordLatency *EndpointLatencyRecorder) ProcessorBuilder { + return func(clientState cache.Indexer, h cache.ResourceEventHandler) cache.ProcessFunc { + return func(obj interface{}, isInitialList bool) error { + for _, d := range obj.(cache.Deltas) { + if recordLatency != nil { + if o, ok := d.Object.(meta.Object); ok { + recordLatency.init(o) + } + } + switch d.Type { + case cache.Sync, cache.Added, cache.Updated: + obj, err := convert(d.Object.(meta.Object)) + if err != nil { + return err + } + if old, exists, err := clientState.Get(obj); err == nil && exists { + if err := clientState.Update(obj); err != nil { + return err + } + h.OnUpdate(old, obj) + } else { + if err := clientState.Add(obj); err != nil { + return err + } + h.OnAdd(obj, isInitialList) + } + if recordLatency != nil { + recordLatency.record() + } + case cache.Deleted: + var obj interface{} + obj, ok := d.Object.(cache.DeletedFinalStateUnknown) + if !ok { + var err error + metaObj, ok := d.Object.(meta.Object) + if !ok { + return fmt.Errorf("unexpected object %v", d.Object) + } + obj, err = convert(metaObj) + if err != nil && err != errPodTerminating { + return err + } + } + + if err := clientState.Delete(obj); err != nil { + return err + } + h.OnDelete(obj) + if !ok && recordLatency != nil { + recordLatency.record() + } + } + } + return nil + } + } +} + +const defaultResyncPeriod = 0 diff --git a/plugin/kubernetes/object/metrics.go b/plugin/kubernetes/object/metrics.go new file mode 100644 index 0000000..f39744b --- /dev/null +++ b/plugin/kubernetes/object/metrics.go @@ -0,0 +1,82 @@ +package object + +import ( + "time" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/log" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + api "k8s.io/api/core/v1" + meta "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +var ( + // DNSProgrammingLatency is defined as the time it took to program a DNS instance - from the time + // a service or pod has changed to the time the change was propagated and was available to be + // served by a DNS server. + // The definition of this SLI can be found at https://github.com/kubernetes/community/blob/master/sig-scalability/slos/dns_programming_latency.md + // Note that the metrics is partially based on the time exported by the endpoints controller on + // the master machine. The measurement may be inaccurate if there is a clock drift between the + // node and master machine. + // The service_kind label can be one of: + // * cluster_ip + // * headless_with_selector + // * headless_without_selector + DNSProgrammingLatency = promauto.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: plugin.Namespace, + Subsystem: "kubernetes", + Name: "dns_programming_duration_seconds", + // From 1 millisecond to ~17 minutes. + Buckets: prometheus.ExponentialBuckets(0.001, 2, 20), + Help: "Histogram of the time (in seconds) it took to program a dns instance.", + }, []string{"service_kind"}) + + // DurationSinceFunc returns the duration elapsed since the given time. + // Added as a global variable to allow injection for testing. 
+ DurationSinceFunc = time.Since
+)
+
+// EndpointLatencyRecorder records the DNS programming latency metric for endpoint objects.
+type EndpointLatencyRecorder struct {
+ TT time.Time
+ ServiceFunc func(meta.Object) []*Service
+ Services []*Service
+}
+
+func (l *EndpointLatencyRecorder) init(o meta.Object) {
+ l.Services = l.ServiceFunc(o)
+ l.TT = time.Time{}
+ stringVal, ok := o.GetAnnotations()[api.EndpointsLastChangeTriggerTime]
+ if ok {
+ tt, err := time.Parse(time.RFC3339Nano, stringVal)
+ if err != nil {
+ log.Warningf("DnsProgrammingLatency cannot be calculated for Endpoints '%s/%s'; invalid %q annotation RFC3339 value of %q",
+ o.GetNamespace(), o.GetName(), api.EndpointsLastChangeTriggerTime, stringVal)
+ // In case of a parse error tt is the zero time, which is ignored downstream.
+ }
+ l.TT = tt
+ }
+}
+
+func (l *EndpointLatencyRecorder) record() {
+ // isHeadless indicates whether the endpoints object belongs to a headless
+ // service (i.e. clusterIP = None). Note that this can be a false negative if the service
+ // informer is lagging, i.e. we may not see a recently created service. Given that services
+ // change far less often than endpoints, cases where this method returns the wrong answer
+ // should be relatively rare. Because of that we intentionally accept this flaw to keep the
+ // solution simple.
+ isHeadless := len(l.Services) == 1 && l.Services[0].Headless()
+
+ if !isHeadless || l.TT.IsZero() {
+ return
+ }
+
+ // If we're here it means that the Endpoints object is for a headless service and that
+ // the Endpoints object was created by the endpoints-controller (because the
+ // LastChangeTriggerTime annotation is set). It means that the corresponding service is a
+ // "headless service with selector".
+ DNSProgrammingLatency.WithLabelValues("headless_with_selector").
+ Observe(DurationSinceFunc(l.TT).Seconds())
+}
diff --git a/plugin/kubernetes/object/namespace.go b/plugin/kubernetes/object/namespace.go
new file mode 100644
index 0000000..ec1b466
--- /dev/null
+++ b/plugin/kubernetes/object/namespace.go
@@ -0,0 +1,61 @@
+package object
+
+import (
+ "fmt"
+
+ api "k8s.io/api/core/v1"
+ meta "k8s.io/apimachinery/pkg/apis/meta/v1"
+ "k8s.io/apimachinery/pkg/runtime"
+)
+
+// Namespace is a stripped down api.Namespace with only the items we need for CoreDNS.
+type Namespace struct {
+ // Don't add new fields to this struct without talking to the CoreDNS maintainers.
+ Version string
+ Name string
+
+ *Empty
+}
+
+// ToNamespace converts an api.Namespace to a *Namespace.
+func ToNamespace(obj meta.Object) (meta.Object, error) {
+ ns, ok := obj.(*api.Namespace)
+ if !ok {
+ return nil, fmt.Errorf("unexpected object %v", obj)
+ }
+ n := &Namespace{
+ Version: ns.GetResourceVersion(),
+ Name: ns.GetName(),
+ }
+ *ns = api.Namespace{}
+ return n, nil
+}
+
+var _ runtime.Object = &Namespace{}
+
+// DeepCopyObject implements the ObjectKind interface.
+func (n *Namespace) DeepCopyObject() runtime.Object {
+ n1 := &Namespace{
+ Version: n.Version,
+ Name: n.Name,
+ }
+ return n1
+}
+
+// GetNamespace implements the metav1.Object interface.
+func (n *Namespace) GetNamespace() string { return "" }
+
+// SetNamespace implements the metav1.Object interface.
+func (n *Namespace) SetNamespace(namespace string) {}
+
+// GetName implements the metav1.Object interface.
+func (n *Namespace) GetName() string { return n.Name }
+
+// SetName implements the metav1.Object interface.
+func (n *Namespace) SetName(name string) {} + +// GetResourceVersion implements the metav1.Object interface. +func (n *Namespace) GetResourceVersion() string { return n.Version } + +// SetResourceVersion implements the metav1.Object interface. +func (n *Namespace) SetResourceVersion(version string) {} diff --git a/plugin/kubernetes/object/object.go b/plugin/kubernetes/object/object.go new file mode 100644 index 0000000..3421779 --- /dev/null +++ b/plugin/kubernetes/object/object.go @@ -0,0 +1,113 @@ +// Package object holds functions that convert the objects from the k8s API in +// to a more memory efficient structures. +// +// Adding new fields to any of the structures defined in pod.go, endpoint.go +// and service.go should not be done lightly as this increases the memory use +// and will leads to OOMs in the k8s scale test. +// +// We can do some optimizations here as well. We store IP addresses as strings, +// this might be moved to uint32 (for v4) for instance, but then we need to +// convert those again. +// +// Also the msg.Service use in this plugin may be deprecated at some point, as +// we don't use most of those features anyway and would free us from the *etcd* +// dependency, where msg.Service is defined. And should save some mem/cpu as we +// convert to and from msg.Services. +package object + +import ( + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/tools/cache" +) + +// ToFunc converts one v1.Object to another v1.Object. +type ToFunc func(v1.Object) (v1.Object, error) + +// ProcessorBuilder returns function to process cache events. +type ProcessorBuilder func(cache.Indexer, cache.ResourceEventHandler) cache.ProcessFunc + +// Empty is an empty struct. +type Empty struct{} + +// GetObjectKind implements the ObjectKind interface as a noop. +func (e *Empty) GetObjectKind() schema.ObjectKind { return schema.EmptyObjectKind } + +// GetGenerateName implements the metav1.Object interface. +func (e *Empty) GetGenerateName() string { return "" } + +// SetGenerateName implements the metav1.Object interface. +func (e *Empty) SetGenerateName(name string) {} + +// GetUID implements the metav1.Object interface. +func (e *Empty) GetUID() types.UID { return "" } + +// SetUID implements the metav1.Object interface. +func (e *Empty) SetUID(uid types.UID) {} + +// GetGeneration implements the metav1.Object interface. +func (e *Empty) GetGeneration() int64 { return 0 } + +// SetGeneration implements the metav1.Object interface. +func (e *Empty) SetGeneration(generation int64) {} + +// GetSelfLink implements the metav1.Object interface. +func (e *Empty) GetSelfLink() string { return "" } + +// SetSelfLink implements the metav1.Object interface. +func (e *Empty) SetSelfLink(selfLink string) {} + +// GetCreationTimestamp implements the metav1.Object interface. +func (e *Empty) GetCreationTimestamp() v1.Time { return v1.Time{} } + +// SetCreationTimestamp implements the metav1.Object interface. +func (e *Empty) SetCreationTimestamp(timestamp v1.Time) {} + +// GetDeletionTimestamp implements the metav1.Object interface. +func (e *Empty) GetDeletionTimestamp() *v1.Time { return &v1.Time{} } + +// SetDeletionTimestamp implements the metav1.Object interface. +func (e *Empty) SetDeletionTimestamp(timestamp *v1.Time) {} + +// GetDeletionGracePeriodSeconds implements the metav1.Object interface. 
+func (e *Empty) GetDeletionGracePeriodSeconds() *int64 { return nil } + +// SetDeletionGracePeriodSeconds implements the metav1.Object interface. +func (e *Empty) SetDeletionGracePeriodSeconds(*int64) {} + +// GetLabels implements the metav1.Object interface. +func (e *Empty) GetLabels() map[string]string { return nil } + +// SetLabels implements the metav1.Object interface. +func (e *Empty) SetLabels(labels map[string]string) {} + +// GetAnnotations implements the metav1.Object interface. +func (e *Empty) GetAnnotations() map[string]string { return nil } + +// SetAnnotations implements the metav1.Object interface. +func (e *Empty) SetAnnotations(annotations map[string]string) {} + +// GetFinalizers implements the metav1.Object interface. +func (e *Empty) GetFinalizers() []string { return nil } + +// SetFinalizers implements the metav1.Object interface. +func (e *Empty) SetFinalizers(finalizers []string) {} + +// GetOwnerReferences implements the metav1.Object interface. +func (e *Empty) GetOwnerReferences() []v1.OwnerReference { return nil } + +// SetOwnerReferences implements the metav1.Object interface. +func (e *Empty) SetOwnerReferences([]v1.OwnerReference) {} + +// GetZZZ_DeprecatedClusterName implements the metav1.Object interface. +func (e *Empty) GetZZZ_DeprecatedClusterName() string { return "" } + +// SetZZZ_DeprecatedClusterName implements the metav1.Object interface. +func (e *Empty) SetZZZ_DeprecatedClusterName(clusterName string) {} + +// GetManagedFields implements the metav1.Object interface. +func (e *Empty) GetManagedFields() []v1.ManagedFieldsEntry { return nil } + +// SetManagedFields implements the metav1.Object interface. +func (e *Empty) SetManagedFields(managedFields []v1.ManagedFieldsEntry) {} diff --git a/plugin/kubernetes/object/pod.go b/plugin/kubernetes/object/pod.go new file mode 100644 index 0000000..9b9d564 --- /dev/null +++ b/plugin/kubernetes/object/pod.go @@ -0,0 +1,78 @@ +package object + +import ( + "errors" + "fmt" + + api "k8s.io/api/core/v1" + meta "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" +) + +// Pod is a stripped down api.Pod with only the items we need for CoreDNS. +type Pod struct { + // Don't add new fields to this struct without talking to the CoreDNS maintainers. + Version string + PodIP string + Name string + Namespace string + + *Empty +} + +var errPodTerminating = errors.New("pod terminating") + +// ToPod converts an api.Pod to a *Pod. +func ToPod(obj meta.Object) (meta.Object, error) { + apiPod, ok := obj.(*api.Pod) + if !ok { + return nil, fmt.Errorf("unexpected object %v", obj) + } + pod := &Pod{ + Version: apiPod.GetResourceVersion(), + PodIP: apiPod.Status.PodIP, + Namespace: apiPod.GetNamespace(), + Name: apiPod.GetName(), + } + t := apiPod.ObjectMeta.DeletionTimestamp + if t != nil && !(*t).Time.IsZero() { + // if the pod is in the process of termination, return an error so it can be ignored + // during add/update event processing + return pod, errPodTerminating + } + + *apiPod = api.Pod{} + + return pod, nil +} + +var _ runtime.Object = &Pod{} + +// DeepCopyObject implements the ObjectKind interface. +func (p *Pod) DeepCopyObject() runtime.Object { + p1 := &Pod{ + Version: p.Version, + PodIP: p.PodIP, + Namespace: p.Namespace, + Name: p.Name, + } + return p1 +} + +// GetNamespace implements the metav1.Object interface. +func (p *Pod) GetNamespace() string { return p.Namespace } + +// SetNamespace implements the metav1.Object interface. 
+func (p *Pod) SetNamespace(namespace string) {} + +// GetName implements the metav1.Object interface. +func (p *Pod) GetName() string { return p.Name } + +// SetName implements the metav1.Object interface. +func (p *Pod) SetName(name string) {} + +// GetResourceVersion implements the metav1.Object interface. +func (p *Pod) GetResourceVersion() string { return p.Version } + +// SetResourceVersion implements the metav1.Object interface. +func (p *Pod) SetResourceVersion(version string) {} diff --git a/plugin/kubernetes/object/service.go b/plugin/kubernetes/object/service.go new file mode 100644 index 0000000..bd3e3d3 --- /dev/null +++ b/plugin/kubernetes/object/service.go @@ -0,0 +1,120 @@ +package object + +import ( + "fmt" + + api "k8s.io/api/core/v1" + meta "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" +) + +// Service is a stripped down api.Service with only the items we need for CoreDNS. +type Service struct { + // Don't add new fields to this struct without talking to the CoreDNS maintainers. + Version string + Name string + Namespace string + Index string + ClusterIPs []string + Type api.ServiceType + ExternalName string + Ports []api.ServicePort + + // ExternalIPs we may want to export. + ExternalIPs []string + + *Empty +} + +// ServiceKey returns a string using for the index. +func ServiceKey(name, namespace string) string { return name + "." + namespace } + +// ToService converts an api.Service to a *Service. +func ToService(obj meta.Object) (meta.Object, error) { + svc, ok := obj.(*api.Service) + if !ok { + return nil, fmt.Errorf("unexpected object %v", obj) + } + s := &Service{ + Version: svc.GetResourceVersion(), + Name: svc.GetName(), + Namespace: svc.GetNamespace(), + Index: ServiceKey(svc.GetName(), svc.GetNamespace()), + Type: svc.Spec.Type, + ExternalName: svc.Spec.ExternalName, + + ExternalIPs: make([]string, len(svc.Status.LoadBalancer.Ingress)+len(svc.Spec.ExternalIPs)), + } + + if len(svc.Spec.ClusterIPs) > 0 { + s.ClusterIPs = make([]string, len(svc.Spec.ClusterIPs)) + copy(s.ClusterIPs, svc.Spec.ClusterIPs) + } else { + s.ClusterIPs = []string{svc.Spec.ClusterIP} + } + + if len(svc.Spec.Ports) == 0 { + // Add sentinel if there are no ports. + s.Ports = []api.ServicePort{{Port: -1}} + } else { + s.Ports = make([]api.ServicePort, len(svc.Spec.Ports)) + copy(s.Ports, svc.Spec.Ports) + } + + li := copy(s.ExternalIPs, svc.Spec.ExternalIPs) + for i, lb := range svc.Status.LoadBalancer.Ingress { + if lb.IP != "" { + s.ExternalIPs[li+i] = lb.IP + continue + } + s.ExternalIPs[li+i] = lb.Hostname + } + + *svc = api.Service{} + + return s, nil +} + +// Headless returns true if the service is headless +func (s *Service) Headless() bool { + return s.ClusterIPs[0] == api.ClusterIPNone +} + +var _ runtime.Object = &Service{} + +// DeepCopyObject implements the ObjectKind interface. +func (s *Service) DeepCopyObject() runtime.Object { + s1 := &Service{ + Version: s.Version, + Name: s.Name, + Namespace: s.Namespace, + Index: s.Index, + Type: s.Type, + ExternalName: s.ExternalName, + ClusterIPs: make([]string, len(s.ClusterIPs)), + Ports: make([]api.ServicePort, len(s.Ports)), + ExternalIPs: make([]string, len(s.ExternalIPs)), + } + copy(s1.ClusterIPs, s.ClusterIPs) + copy(s1.Ports, s.Ports) + copy(s1.ExternalIPs, s.ExternalIPs) + return s1 +} + +// GetNamespace implements the metav1.Object interface. +func (s *Service) GetNamespace() string { return s.Namespace } + +// SetNamespace implements the metav1.Object interface. 
+func (s *Service) SetNamespace(namespace string) {} + +// GetName implements the metav1.Object interface. +func (s *Service) GetName() string { return s.Name } + +// SetName implements the metav1.Object interface. +func (s *Service) SetName(name string) {} + +// GetResourceVersion implements the metav1.Object interface. +func (s *Service) GetResourceVersion() string { return s.Version } + +// SetResourceVersion implements the metav1.Object interface. +func (s *Service) SetResourceVersion(version string) {} diff --git a/plugin/kubernetes/parse.go b/plugin/kubernetes/parse.go new file mode 100644 index 0000000..4690c81 --- /dev/null +++ b/plugin/kubernetes/parse.go @@ -0,0 +1,103 @@ +package kubernetes + +import ( + "github.com/coredns/coredns/plugin/pkg/dnsutil" + + "github.com/miekg/dns" +) + +type recordRequest struct { + // The named port from the kubernetes DNS spec, this is the service part (think _https) from a well formed + // SRV record. + port string + // The protocol is usually _udp or _tcp (if set), and comes from the protocol part of a well formed + // SRV record. + protocol string + endpoint string + // The servicename used in Kubernetes. + service string + // The namespace used in Kubernetes. + namespace string + // A each name can be for a pod or a service, here we track what we've seen, either "pod" or "service". + podOrSvc string +} + +// parseRequest parses the qname to find all the elements we need for querying k8s. Anything +// that is not parsed will have the wildcard "*" value (except r.endpoint). +// Potential underscores are stripped from _port and _protocol. +func parseRequest(name, zone string) (r recordRequest, err error) { + // 3 Possible cases: + // 1. _port._protocol.service.namespace.pod|svc.zone + // 2. (endpoint): endpoint.service.namespace.pod|svc.zone + // 3. (service): service.namespace.pod|svc.zone + + base, _ := dnsutil.TrimZone(name, zone) + // return NODATA for apex queries + if base == "" || base == Svc || base == Pod { + return r, nil + } + segs := dns.SplitDomainName(base) + + last := len(segs) - 1 + if last < 0 { + return r, nil + } + r.podOrSvc = segs[last] + if r.podOrSvc != Pod && r.podOrSvc != Svc { + return r, errInvalidRequest + } + last-- + if last < 0 { + return r, nil + } + + r.namespace = segs[last] + last-- + if last < 0 { + return r, nil + } + + r.service = segs[last] + last-- + if last < 0 { + return r, nil + } + + // Because of ambiguity we check the labels left: 1: an endpoint. 2: port and protocol. + // Anything else is a query that is too long to answer and can safely be delegated to return an nxdomain. + switch last { + case 0: // endpoint only + r.endpoint = segs[last] + case 1: // service and port + r.protocol = stripUnderscore(segs[last]) + r.port = stripUnderscore(segs[last-1]) + + default: // too long + return r, errInvalidRequest + } + + return r, nil +} + +// stripUnderscore removes a prefixed underscore from s. +func stripUnderscore(s string) string { + if len(s) == 0 { + return s + } + if s[0] != '_' { + return s + } + return s[1:] +} + +// String returns a string representation of r, it just returns all fields concatenated with dots. +// This is mostly used in tests. +func (r recordRequest) String() string { + s := r.port + s += "." + r.protocol + s += "." + r.endpoint + s += "." + r.service + s += "." + r.namespace + s += "." 
+ r.podOrSvc + return s +} diff --git a/plugin/kubernetes/parse_test.go b/plugin/kubernetes/parse_test.go new file mode 100644 index 0000000..739a405 --- /dev/null +++ b/plugin/kubernetes/parse_test.go @@ -0,0 +1,62 @@ +package kubernetes + +import ( + "testing" + + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +func TestParseRequest(t *testing.T) { + tests := []struct { + query string + expected string // output from r.String() + }{ + // valid SRV request + {"_http._tcp.webs.mynamespace.svc.inter.webs.tests.", "http.tcp..webs.mynamespace.svc"}, + // A request of endpoint + {"1-2-3-4.webs.mynamespace.svc.inter.webs.tests.", "..1-2-3-4.webs.mynamespace.svc"}, + // bare zone + {"inter.webs.tests.", "....."}, + // bare svc type + {"svc.inter.webs.tests.", "....."}, + // bare pod type + {"pod.inter.webs.tests.", "....."}, + // SRV request with empty segments + {"..webs.mynamespace.svc.inter.webs.tests.", "...webs.mynamespace.svc"}, + } + for i, tc := range tests { + m := new(dns.Msg) + m.SetQuestion(tc.query, dns.TypeA) + state := request.Request{Zone: zone, Req: m} + + r, e := parseRequest(state.Name(), state.Zone) + if e != nil { + t.Errorf("Test %d, expected no error, got '%v'.", i, e) + } + rs := r.String() + if rs != tc.expected { + t.Errorf("Test %d, expected (stringified) recordRequest: %s, got %s", i, tc.expected, rs) + } + } +} + +func TestParseInvalidRequest(t *testing.T) { + invalid := []string{ + "webs.mynamespace.pood.inter.webs.test.", // Request must be for pod or svc subdomain. + "too.long.for.what.I.am.trying.to.pod.inter.webs.tests.", // Too long. + } + + for i, query := range invalid { + m := new(dns.Msg) + m.SetQuestion(query, dns.TypeA) + state := request.Request{Zone: zone, Req: m} + + if _, e := parseRequest(state.Name(), state.Zone); e == nil { + t.Errorf("Test %d: expected error from %s, got none", i, query) + } + } +} + +const zone = "inter.webs.tests." diff --git a/plugin/kubernetes/ready.go b/plugin/kubernetes/ready.go new file mode 100644 index 0000000..2625f3b --- /dev/null +++ b/plugin/kubernetes/ready.go @@ -0,0 +1,4 @@ +package kubernetes + +// Ready implements the ready.Readiness interface. +func (k *Kubernetes) Ready() bool { return k.APIConn.HasSynced() } diff --git a/plugin/kubernetes/reverse.go b/plugin/kubernetes/reverse.go new file mode 100644 index 0000000..26fc3b4 --- /dev/null +++ b/plugin/kubernetes/reverse.go @@ -0,0 +1,55 @@ +package kubernetes + +import ( + "context" + "strings" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/etcd/msg" + "github.com/coredns/coredns/plugin/pkg/dnsutil" + "github.com/coredns/coredns/request" +) + +// Reverse implements the ServiceBackend interface. 
+func (k *Kubernetes) Reverse(ctx context.Context, state request.Request, exact bool, opt plugin.Options) ([]msg.Service, error) { + ip := dnsutil.ExtractAddressFromReverse(state.Name()) + if ip == "" { + _, e := k.Records(ctx, state, exact) + return nil, e + } + + records := k.serviceRecordForIP(ip, state.Name()) + if len(records) == 0 { + return records, errNoItems + } + return records, nil +} + +// serviceRecordForIP gets a service record with a cluster ip matching the ip argument +// If a service cluster ip does not match, it checks all endpoints +func (k *Kubernetes) serviceRecordForIP(ip, name string) []msg.Service { + // First check services with cluster ips + for _, service := range k.APIConn.SvcIndexReverse(ip) { + if len(k.Namespaces) > 0 && !k.namespaceExposed(service.Namespace) { + continue + } + domain := strings.Join([]string{service.Name, service.Namespace, Svc, k.primaryZone()}, ".") + return []msg.Service{{Host: domain, TTL: k.ttl}} + } + // If no cluster ips match, search endpoints + var svcs []msg.Service + for _, ep := range k.APIConn.EpIndexReverse(ip) { + if len(k.Namespaces) > 0 && !k.namespaceExposed(ep.Namespace) { + continue + } + for _, eps := range ep.Subsets { + for _, addr := range eps.Addresses { + if addr.IP == ip { + domain := strings.Join([]string{endpointHostname(addr, k.endpointNameMode), ep.Index, Svc, k.primaryZone()}, ".") + svcs = append(svcs, msg.Service{Host: domain, TTL: k.ttl}) + } + } + } + } + return svcs +} diff --git a/plugin/kubernetes/reverse_test.go b/plugin/kubernetes/reverse_test.go new file mode 100644 index 0000000..370b9f9 --- /dev/null +++ b/plugin/kubernetes/reverse_test.go @@ -0,0 +1,256 @@ +package kubernetes + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin/kubernetes/object" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" + api "k8s.io/api/core/v1" + meta "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +type APIConnReverseTest struct{} + +func (APIConnReverseTest) HasSynced() bool { return true } +func (APIConnReverseTest) Run() {} +func (APIConnReverseTest) Stop() error { return nil } +func (APIConnReverseTest) PodIndex(string) []*object.Pod { return nil } +func (APIConnReverseTest) EpIndex(string) []*object.Endpoints { return nil } +func (APIConnReverseTest) EndpointsList() []*object.Endpoints { return nil } +func (APIConnReverseTest) ServiceList() []*object.Service { return nil } +func (APIConnReverseTest) SvcExtIndexReverse(string) []*object.Service { return nil } +func (APIConnReverseTest) Modified(bool) int64 { return 0 } + +func (APIConnReverseTest) SvcIndex(svc string) []*object.Service { + if svc != "svc1.testns" { + return nil + } + svcs := []*object.Service{ + { + Name: "svc1", + Namespace: "testns", + ClusterIPs: []string{"192.168.1.100"}, + Ports: []api.ServicePort{{Name: "http", Protocol: "tcp", Port: 80}}, + }, + } + return svcs +} + +func (APIConnReverseTest) SvcIndexReverse(ip string) []*object.Service { + if ip != "192.168.1.100" { + return nil + } + svcs := []*object.Service{ + { + Name: "svc1", + Namespace: "testns", + ClusterIPs: []string{"192.168.1.100"}, + Ports: []api.ServicePort{{Name: "http", Protocol: "tcp", Port: 80}}, + }, + } + return svcs +} + +func (APIConnReverseTest) EpIndexReverse(ip string) []*object.Endpoints { + ep1s1 := object.Endpoints{ + Subsets: []object.EndpointSubset{ + { + Addresses: []object.EndpointAddress{ + {IP: "10.0.0.100", Hostname: "ep1a"}, + {IP: "10.0.0.99", Hostname: "double-ep"}, 
// this endpoint is used by two services + }, + Ports: []object.EndpointPort{ + {Port: 80, Protocol: "tcp", Name: "http"}, + }, + }, + }, + Name: "svc1-slice1", + Namespace: "testns", + Index: object.EndpointsKey("svc1", "testns"), + } + ep1s2 := object.Endpoints{ + Subsets: []object.EndpointSubset{ + { + Addresses: []object.EndpointAddress{ + {IP: "1234:abcd::1", Hostname: "ep1b"}, + {IP: "fd00:77:30::a", Hostname: "ip6svc1ex"}, + {IP: "fd00:77:30::2:9ba6", Hostname: "ip6svc1in"}, + }, + Ports: []object.EndpointPort{ + {Port: 80, Protocol: "tcp", Name: "http"}, + }, + }, + }, + Name: "svc1-slice2", + Namespace: "testns", + Index: object.EndpointsKey("svc1", "testns"), + } + ep1s3 := object.Endpoints{ + Subsets: []object.EndpointSubset{ + { + Addresses: []object.EndpointAddress{ + {IP: "10.0.0.100", Hostname: "ep1a"}, // duplicate endpointslice address + }, + Ports: []object.EndpointPort{ + {Port: 80, Protocol: "tcp", Name: "http"}, + }, + }, + }, + Name: "svc1-ccccc", + Namespace: "testns", + Index: object.EndpointsKey("svc1", "testns"), + } + ep2 := object.Endpoints{ + Subsets: []object.EndpointSubset{ + { + Addresses: []object.EndpointAddress{ + {IP: "10.0.0.99", Hostname: "double-ep"}, // this endpoint is used by two services + }, + Ports: []object.EndpointPort{ + {Port: 80, Protocol: "tcp", Name: "http"}, + }, + }, + }, + Name: "svc2-slice1", + Namespace: "testns", + Index: object.EndpointsKey("svc2", "testns"), + } + switch ip { + case "1234:abcd::1": + fallthrough + case "fd00:77:30::a": + fallthrough + case "fd00:77:30::2:9ba6": + return []*object.Endpoints{&ep1s2} + case "10.0.0.100": // two EndpointSlices for a Service contain this IP (EndpointSlice skew) + return []*object.Endpoints{&ep1s1, &ep1s3} + case "10.0.0.99": // two different Services select this IP + return []*object.Endpoints{&ep1s1, &ep2} + } + return nil +} + +func (APIConnReverseTest) GetNodeByName(ctx context.Context, name string) (*api.Node, error) { + return &api.Node{ + ObjectMeta: meta.ObjectMeta{ + Name: "test.node.foo.bar", + }, + }, nil +} + +func (APIConnReverseTest) GetNamespaceByName(name string) (*object.Namespace, error) { + return &object.Namespace{ + Name: name, + }, nil +} + +func TestReverse(t *testing.T) { + k := New([]string{"cluster.local.", "0.10.in-addr.arpa.", "168.192.in-addr.arpa.", "0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.d.c.b.a.4.3.2.1.ip6.arpa.", "0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.3.0.0.7.7.0.0.0.0.d.f.ip6.arpa."}) + k.APIConn = &APIConnReverseTest{} + + tests := []test.Case{ + { + Qname: "100.0.0.10.in-addr.arpa.", Qtype: dns.TypePTR, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.PTR("100.0.0.10.in-addr.arpa. 5 IN PTR ep1a.svc1.testns.svc.cluster.local."), + }, + }, + { + Qname: "100.1.168.192.in-addr.arpa.", Qtype: dns.TypePTR, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.PTR("100.1.168.192.in-addr.arpa. 5 IN PTR svc1.testns.svc.cluster.local."), + }, + }, + { // A PTR record query for an existing ipv6 endpoint should return a record + Qname: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.d.c.b.a.4.3.2.1.ip6.arpa.", Qtype: dns.TypePTR, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.PTR("1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.d.c.b.a.4.3.2.1.ip6.arpa. 
5 IN PTR ep1b.svc1.testns.svc.cluster.local."), + }, + }, + { // A PTR record query for an existing ipv6 endpoint should return a record + Qname: "a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.3.0.0.7.7.0.0.0.0.d.f.ip6.arpa.", Qtype: dns.TypePTR, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.PTR("a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.3.0.0.7.7.0.0.0.0.d.f.ip6.arpa. 5 IN PTR ip6svc1ex.svc1.testns.svc.cluster.local."), + }, + }, + { // A PTR record query for an existing ipv6 endpoint should return a record + Qname: "6.a.b.9.2.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.3.0.0.7.7.0.0.0.0.d.f.ip6.arpa.", Qtype: dns.TypePTR, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.PTR("6.a.b.9.2.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.3.0.0.7.7.0.0.0.0.d.f.ip6.arpa. 5 IN PTR ip6svc1in.svc1.testns.svc.cluster.local."), + }, + }, + { + Qname: "101.0.0.10.in-addr.arpa.", Qtype: dns.TypePTR, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("0.10.in-addr.arpa. 5 IN SOA ns.dns.0.10.in-addr.arpa. hostmaster.0.10.in-addr.arpa. 1502782828 7200 1800 86400 5"), + }, + }, + { + Qname: "example.org.cluster.local.", Qtype: dns.TypePTR, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1502989566 7200 1800 86400 5"), + }, + }, + { + Qname: "svc1.testns.svc.cluster.local.", Qtype: dns.TypePTR, + Rcode: dns.RcodeSuccess, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1502989566 7200 1800 86400 5"), + }, + }, + { + Qname: "svc1.testns.svc.0.10.in-addr.arpa.", Qtype: dns.TypeA, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("0.10.in-addr.arpa. 5 IN SOA ns.dns.0.10.in-addr.arpa. hostmaster.0.10.in-addr.arpa. 1502989566 7200 1800 86400 5"), + }, + }, + { + Qname: "100.0.0.10.cluster.local.", Qtype: dns.TypePTR, + Rcode: dns.RcodeNameError, + Ns: []dns.RR{ + test.SOA("cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 1502989566 7200 1800 86400 5"), + }, + }, + { + Qname: "99.0.0.10.in-addr.arpa.", Qtype: dns.TypePTR, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.PTR("99.0.0.10.in-addr.arpa. 5 IN PTR double-ep.svc1.testns.svc.cluster.local."), + test.PTR("99.0.0.10.in-addr.arpa. 
5 IN PTR double-ep.svc2.testns.svc.cluster.local."), + }, + }, + } + + ctx := context.TODO() + for i, tc := range tests { + r := tc.Msg() + + w := dnstest.NewRecorder(&test.ResponseWriter{}) + + _, err := k.ServeDNS(ctx, w, r) + if err != tc.Error { + t.Errorf("Test %d: expected no error, got %v", i, err) + return + } + + resp := w.Msg + if resp == nil { + t.Fatalf("Test %d: got nil message and no error for: %s %d", i, r.Question[0].Name, r.Question[0].Qtype) + } + if err := test.SortAndCheck(resp, tc); err != nil { + t.Error(err) + } + } +} diff --git a/plugin/kubernetes/setup.go b/plugin/kubernetes/setup.go new file mode 100644 index 0000000..0b988a9 --- /dev/null +++ b/plugin/kubernetes/setup.go @@ -0,0 +1,251 @@ +package kubernetes + +import ( + "context" + "errors" + "fmt" + "strconv" + "strings" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/dnsutil" + clog "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/plugin/pkg/upstream" + + "github.com/go-logr/logr" + "github.com/miekg/dns" + meta "k8s.io/apimachinery/pkg/apis/meta/v1" + _ "k8s.io/client-go/plugin/pkg/client/auth/oidc" // pull this in here, because we want it excluded if plugin.cfg doesn't have k8s + "k8s.io/client-go/tools/clientcmd" + "k8s.io/klog/v2" +) + +const pluginName = "kubernetes" + +var log = clog.NewWithPlugin(pluginName) + +func init() { plugin.Register(pluginName, setup) } + +func setup(c *caddy.Controller) error { + // Do not call klog.InitFlags(nil) here. It will cause reload to panic. + klog.SetLogger(logr.New(&loggerAdapter{log})) + + k, err := kubernetesParse(c) + if err != nil { + return plugin.Error(pluginName, err) + } + + onStart, onShut, err := k.InitKubeCache(context.Background()) + if err != nil { + return plugin.Error(pluginName, err) + } + if onStart != nil { + c.OnStartup(onStart) + } + if onShut != nil { + c.OnShutdown(onShut) + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + k.Next = next + return k + }) + + // get locally bound addresses + c.OnStartup(func() error { + k.localIPs = boundIPs(c) + return nil + }) + + return nil +} + +func kubernetesParse(c *caddy.Controller) (*Kubernetes, error) { + var ( + k8s *Kubernetes + err error + ) + + i := 0 + for c.Next() { + if i > 0 { + return nil, plugin.ErrOnce + } + i++ + + k8s, err = ParseStanza(c) + if err != nil { + return k8s, err + } + } + return k8s, nil +} + +// ParseStanza parses a kubernetes stanza +func ParseStanza(c *caddy.Controller) (*Kubernetes, error) { + k8s := New([]string{""}) + k8s.autoPathSearch = searchFromResolvConf() + + opts := dnsControlOpts{ + initEndpointsCache: true, + ignoreEmptyService: false, + } + k8s.opts = opts + + k8s.Zones = plugin.OriginsFromArgsOrServerBlock(c.RemainingArgs(), c.ServerBlockKeys) + + k8s.primaryZoneIndex = -1 + for i, z := range k8s.Zones { + if dnsutil.IsReverse(z) > 0 { + continue + } + k8s.primaryZoneIndex = i + break + } + + if k8s.primaryZoneIndex == -1 { + return nil, errors.New("non-reverse zone name must be used") + } + + k8s.Upstream = upstream.New() + + for c.NextBlock() { + switch c.Val() { + case "endpoint_pod_names": + args := c.RemainingArgs() + if len(args) > 0 { + return nil, c.ArgErr() + } + k8s.endpointNameMode = true + continue + case "pods": + args := c.RemainingArgs() + if len(args) == 1 { + switch args[0] { + case podModeDisabled, podModeInsecure, podModeVerified: + k8s.podMode = args[0] + default: + return 
nil, fmt.Errorf("wrong value for pods: %s, must be one of: disabled, verified, insecure", args[0]) + } + continue + } + return nil, c.ArgErr() + case "namespaces": + args := c.RemainingArgs() + if len(args) > 0 { + for _, a := range args { + k8s.Namespaces[a] = struct{}{} + } + continue + } + return nil, c.ArgErr() + case "endpoint": + args := c.RemainingArgs() + if len(args) > 0 { + // Multiple endpoints are deprecated but still could be specified, + // only the first one be used, though + k8s.APIServerList = args + if len(args) > 1 { + log.Warningf("Multiple endpoints have been deprecated, only the first specified endpoint '%s' is used", args[0]) + } + continue + } + return nil, c.ArgErr() + case "tls": // cert key cacertfile + args := c.RemainingArgs() + if len(args) == 3 { + k8s.APIClientCert, k8s.APIClientKey, k8s.APICertAuth = args[0], args[1], args[2] + continue + } + return nil, c.ArgErr() + case "labels": + args := c.RemainingArgs() + if len(args) > 0 { + labelSelectorString := strings.Join(args, " ") + ls, err := meta.ParseToLabelSelector(labelSelectorString) + if err != nil { + return nil, fmt.Errorf("unable to parse label selector value: '%v': %v", labelSelectorString, err) + } + k8s.opts.labelSelector = ls + continue + } + return nil, c.ArgErr() + case "namespace_labels": + args := c.RemainingArgs() + if len(args) > 0 { + namespaceLabelSelectorString := strings.Join(args, " ") + nls, err := meta.ParseToLabelSelector(namespaceLabelSelectorString) + if err != nil { + return nil, fmt.Errorf("unable to parse namespace_label selector value: '%v': %v", namespaceLabelSelectorString, err) + } + k8s.opts.namespaceLabelSelector = nls + continue + } + return nil, c.ArgErr() + case "fallthrough": + k8s.Fall.SetZonesFromArgs(c.RemainingArgs()) + case "ttl": + args := c.RemainingArgs() + if len(args) == 0 { + return nil, c.ArgErr() + } + t, err := strconv.Atoi(args[0]) + if err != nil { + return nil, err + } + if t < 0 || t > 3600 { + return nil, c.Errf("ttl must be in range [0, 3600]: %d", t) + } + k8s.ttl = uint32(t) + case "noendpoints": + if len(c.RemainingArgs()) != 0 { + return nil, c.ArgErr() + } + k8s.opts.initEndpointsCache = false + case "ignore": + args := c.RemainingArgs() + if len(args) > 0 { + ignore := args[0] + if ignore == "empty_service" { + k8s.opts.ignoreEmptyService = true + continue + } else { + return nil, fmt.Errorf("unable to parse ignore value: '%v'", ignore) + } + } + case "kubeconfig": + args := c.RemainingArgs() + if len(args) != 1 && len(args) != 2 { + return nil, c.ArgErr() + } + overrides := &clientcmd.ConfigOverrides{} + if len(args) == 2 { + overrides.CurrentContext = args[1] + } + config := clientcmd.NewNonInteractiveDeferredLoadingClientConfig( + &clientcmd.ClientConfigLoadingRules{ExplicitPath: args[0]}, + overrides, + ) + k8s.ClientConfig = config + default: + return nil, c.Errf("unknown property '%s'", c.Val()) + } + } + + if len(k8s.Namespaces) != 0 && k8s.opts.namespaceLabelSelector != nil { + return nil, c.Errf("namespaces and namespace_labels cannot both be set") + } + + return k8s, nil +} + +func searchFromResolvConf() []string { + rc, err := dns.ClientConfigFromFile("/etc/resolv.conf") + if err != nil { + return nil + } + plugin.Zones(rc.Search).Normalize() + return rc.Search +} diff --git a/plugin/kubernetes/setup_test.go b/plugin/kubernetes/setup_test.go new file mode 100644 index 0000000..52b0d3f --- /dev/null +++ b/plugin/kubernetes/setup_test.go @@ -0,0 +1,612 @@ +package kubernetes + +import ( + "strings" + "testing" + + 
"github.com/coredns/caddy" + "github.com/coredns/coredns/plugin/pkg/fall" + + meta "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func TestKubernetesParse(t *testing.T) { + tests := []struct { + input string // Corefile data as string + shouldErr bool // true if test case is expected to produce an error. + expectedErrContent string // substring from the expected error. Empty for positive cases. + expectedZoneCount int // expected count of defined zones. + expectedNSCount int // expected count of namespaces. + expectedLabelSelector string // expected label selector value + expectedNamespaceLabelSelector string // expected namespace label selector value + expectedPodMode string + expectedFallthrough fall.F + }{ + // positive + { + `kubernetes coredns.local`, + false, + "", + 1, + 0, + "", + "", + podModeDisabled, + fall.Zero, + }, + { + `kubernetes coredns.local test.local`, + false, + "", + 2, + 0, + "", + "", + podModeDisabled, + fall.Zero, + }, + { + `kubernetes coredns.local { +}`, + false, + "", + 1, + 0, + "", + "", + podModeDisabled, + fall.Zero, + }, + { + `kubernetes coredns.local { + endpoint http://localhost:9090 http://localhost:9091 +}`, + false, + "", + 1, + 0, + "", + "", + podModeDisabled, + fall.Zero, + }, + { + `kubernetes coredns.local { + namespaces demo +}`, + false, + "", + 1, + 1, + "", + "", + podModeDisabled, + fall.Zero, + }, + { + `kubernetes coredns.local { + namespaces demo test +}`, + false, + "", + 1, + 2, + "", + "", + podModeDisabled, + fall.Zero, + }, + { + `kubernetes coredns.local { + labels environment=prod +}`, + false, + "", + 1, + 0, + "environment=prod", + "", + podModeDisabled, + fall.Zero, + }, + { + `kubernetes coredns.local { + labels environment in (production, staging, qa),application=nginx +}`, + false, + "", + 1, + 0, + "application=nginx,environment in (production,qa,staging)", + "", + podModeDisabled, + fall.Zero, + }, + { + `kubernetes coredns.local { + namespace_labels istio-injection=enabled +}`, + false, + "", + 1, + 0, + "", + "istio-injection=enabled", + podModeDisabled, + fall.Zero, + }, + { + `kubernetes coredns.local { + namespaces foo bar + namespace_labels istio-injection=enabled +}`, + true, + "Error during parsing: namespaces and namespace_labels cannot both be set", + -1, + 0, + "", + "istio-injection=enabled", + podModeDisabled, + fall.Zero, + }, + { + `kubernetes coredns.local test.local { + endpoint http://localhost:8080 + namespaces demo test + labels environment in (production, staging, qa),application=nginx + fallthrough +}`, + false, + "", + 2, + 2, + "application=nginx,environment in (production,qa,staging)", + "", + podModeDisabled, + fall.Root, + }, + // negative + { + `kubernetes coredns.local { + endpoint +}`, + true, + "rong argument count or unexpected line ending", + -1, + -1, + "", + "", + podModeDisabled, + fall.Zero, + }, + { + `kubernetes coredns.local { + namespaces +}`, + true, + "rong argument count or unexpected line ending", + -1, + -1, + "", + "", + podModeDisabled, + fall.Zero, + }, + { + `kubernetes coredns.local { + labels +}`, + true, + "rong argument count or unexpected line ending", + -1, + 0, + "", + "", + podModeDisabled, + fall.Zero, + }, + { + `kubernetes coredns.local { + labels environment in (production, qa +}`, + true, + "unable to parse label selector", + -1, + 0, + "", + "", + podModeDisabled, + fall.Zero, + }, + // pods disabled + { + `kubernetes coredns.local { + pods disabled +}`, + false, + "", + 1, + 0, + "", + "", + podModeDisabled, + fall.Zero, + }, + // pods insecure + { + 
`kubernetes coredns.local { + pods insecure +}`, + false, + "", + 1, + 0, + "", + "", + podModeInsecure, + fall.Zero, + }, + // pods verified + { + `kubernetes coredns.local { + pods verified +}`, + false, + "", + 1, + 0, + "", + "", + podModeVerified, + fall.Zero, + }, + // pods invalid + { + `kubernetes coredns.local { + pods giant_seed +}`, + true, + "rong value for pods", + -1, + 0, + "", + "", + podModeVerified, + fall.Zero, + }, + // fallthrough with zones + { + `kubernetes coredns.local { + fallthrough ip6.arpa inaddr.arpa foo.com +}`, + false, + "rong argument count", + 1, + 0, + "", + "", + podModeDisabled, + fall.F{Zones: []string{"ip6.arpa.", "inaddr.arpa.", "foo.com."}}, + }, + // More than one Kubernetes not allowed + { + `kubernetes coredns.local +kubernetes cluster.local`, + true, + "this plugin", + -1, + 0, + "", + "", + podModeDisabled, + fall.Zero, + }, + { + `kubernetes coredns.local { + kubeconfig +}`, + true, + "Wrong argument count or unexpected line ending after", + -1, + 0, + "", + "", + podModeDisabled, + fall.Zero, + }, + { + `kubernetes coredns.local { + kubeconfig file context extraarg +}`, + true, + "Wrong argument count or unexpected line ending after", + -1, + 0, + "", + "", + podModeDisabled, + fall.Zero, + }, + { + `kubernetes coredns.local { + kubeconfig file +}`, + false, + "", + 1, + 0, + "", + "", + podModeDisabled, + fall.Zero, + }, + { + `kubernetes coredns.local { + kubeconfig file context +}`, + false, + "", + 1, + 0, + "", + "", + podModeDisabled, + fall.Zero, + }, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + k8sController, err := kubernetesParse(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected error, but did not find error for input '%s'. Error was: '%v'", i, test.input, err) + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: Expected no error but found one for input %s. Error was: %v", i, test.input, err) + continue + } + + if test.shouldErr && (len(test.expectedErrContent) < 1) { + t.Fatalf("Test %d: Test marked as expecting an error, but no expectedErrContent provided for input '%s'. Error was: '%v'", i, test.input, err) + } + + if test.shouldErr && (test.expectedZoneCount >= 0) { + t.Errorf("Test %d: Test marked as expecting an error, but provides value for expectedZoneCount!=-1 for input '%s'. Error was: '%v'", i, test.input, err) + } + + if !strings.Contains(err.Error(), test.expectedErrContent) { + t.Errorf("Test %d: Expected error to contain: %v, found error: %v, input: %s", i, test.expectedErrContent, err, test.input) + } + continue + } + + // No error was raised, so validate initialization of k8sController + // Zones + foundZoneCount := len(k8sController.Zones) + if foundZoneCount != test.expectedZoneCount { + t.Errorf("Test %d: Expected kubernetes controller to be initialized with %d zones, instead found %d zones: '%v' for input '%s'", i, test.expectedZoneCount, foundZoneCount, k8sController.Zones, test.input) + } + + // Namespaces + foundNSCount := len(k8sController.Namespaces) + if foundNSCount != test.expectedNSCount { + t.Errorf("Test %d: Expected kubernetes controller to be initialized with %d namespaces. 
Instead found %d namespaces: '%v' for input '%s'", i, test.expectedNSCount, foundNSCount, k8sController.Namespaces, test.input) + } + + // Labels + if k8sController.opts.labelSelector != nil { + foundLabelSelectorString := meta.FormatLabelSelector(k8sController.opts.labelSelector) + if foundLabelSelectorString != test.expectedLabelSelector { + t.Errorf("Test %d: Expected kubernetes controller to be initialized with label selector '%s'. Instead found selector '%s' for input '%s'", i, test.expectedLabelSelector, foundLabelSelectorString, test.input) + } + } + // Pods + foundPodMode := k8sController.podMode + if foundPodMode != test.expectedPodMode { + t.Errorf("Test %d: Expected kubernetes controller to be initialized with pod mode '%s'. Instead found pod mode '%s' for input '%s'", i, test.expectedPodMode, foundPodMode, test.input) + } + + // fallthrough + if !k8sController.Fall.Equal(test.expectedFallthrough) { + t.Errorf("Test %d: Expected kubernetes controller to be initialized with fallthrough '%v'. Instead found fallthrough '%v' for input '%s'", i, test.expectedFallthrough, k8sController.Fall, test.input) + } + } +} + +func TestKubernetesParseEndpointPodNames(t *testing.T) { + tests := []struct { + input string // Corefile data as string + shouldErr bool // true if test case is expected to produce an error. + expectedErrContent string // substring from the expected error. Empty for positive cases. + expectedEndpointMode bool + }{ + // valid endpoints mode + { + `kubernetes coredns.local { + endpoint_pod_names +}`, + false, + "", + true, + }, + // endpoints invalid + { + `kubernetes coredns.local { + endpoint_pod_names giant_seed +}`, + true, + "rong argument count or unexpected", + false, + }, + // endpoint not set + { + `kubernetes coredns.local { +}`, + false, + "", + false, + }, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + k8sController, err := kubernetesParse(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected error, but did not find error for input '%s'. Error was: '%v'", i, test.input, err) + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: Expected no error but found one for input %s. Error was: %v", i, test.input, err) + continue + } + + if !strings.Contains(err.Error(), test.expectedErrContent) { + t.Errorf("Test %d: Expected error to contain: %v, found error: %v, input: %s", i, test.expectedErrContent, err, test.input) + } + continue + } + + // Endpoints + foundEndpointNameMode := k8sController.endpointNameMode + if foundEndpointNameMode != test.expectedEndpointMode { + t.Errorf("Test %d: Expected kubernetes controller to be initialized with endpoints mode '%v'. Instead found endpoints mode '%v' for input '%s'", i, test.expectedEndpointMode, foundEndpointNameMode, test.input) + } + } +} + +func TestKubernetesParseNoEndpoints(t *testing.T) { + tests := []struct { + input string // Corefile data as string + shouldErr bool // true if test case is expected to produce an error. + expectedErrContent string // substring from the expected error. Empty for positive cases. 
+ expectedEndpointsInit bool + }{ + // valid + { + `kubernetes coredns.local { + noendpoints +}`, + false, + "", + false, + }, + // invalid + { + `kubernetes coredns.local { + noendpoints ixnay on the endpointsay +}`, + true, + "rong argument count or unexpected", + true, + }, + // not set + { + `kubernetes coredns.local { +}`, + false, + "", + true, + }, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + k8sController, err := kubernetesParse(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected error, but did not find error for input '%s'. Error was: '%v'", i, test.input, err) + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: Expected no error but found one for input %s. Error was: %v", i, test.input, err) + continue + } + + if !strings.Contains(err.Error(), test.expectedErrContent) { + t.Errorf("Test %d: Expected error to contain: %v, found error: %v, input: %s", i, test.expectedErrContent, err, test.input) + } + continue + } + + foundEndpointsInit := k8sController.opts.initEndpointsCache + if foundEndpointsInit != test.expectedEndpointsInit { + t.Errorf("Test %d: Expected kubernetes controller to be initialized with endpoints watch '%v'. Instead found endpoints watch '%v' for input '%s'", i, test.expectedEndpointsInit, foundEndpointsInit, test.input) + } + } +} + +func TestKubernetesParseIgnoreEmptyService(t *testing.T) { + tests := []struct { + input string // Corefile data as string + shouldErr bool // true if test case is expected to produce an error. + expectedErrContent string // substring from the expected error. Empty for positive cases. + expectedEndpointsInit bool + }{ + // valid + { + `kubernetes coredns.local { + ignore empty_service +}`, + false, + "", + true, + }, + // invalid + { + `kubernetes coredns.local { + ignore ixnay on the endpointsay +}`, + true, + "unable to parse ignore value", + false, + }, + { + `kubernetes coredns.local { + ignore empty_service ixnay on the endpointsay +}`, + false, + "", + true, + }, + // not set + { + `kubernetes coredns.local { +}`, + false, + "", + false, + }, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + k8sController, err := kubernetesParse(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected error, but did not find error for input '%s'. Error was: '%v'", i, test.input, err) + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: Expected no error but found one for input %s. Error was: %v", i, test.input, err) + continue + } + + if !strings.Contains(err.Error(), test.expectedErrContent) { + t.Errorf("Test %d: Expected error to contain: %v, found error: %v, input: %s", i, test.expectedErrContent, err, test.input) + } + continue + } + + foundIgnoreEmptyService := k8sController.opts.ignoreEmptyService + if foundIgnoreEmptyService != test.expectedEndpointsInit { + t.Errorf("Test %d: Expected kubernetes controller to be initialized with ignore empty_service '%v'. 
Instead found ignore empty_service watch '%v' for input '%s'", i, test.expectedEndpointsInit, foundIgnoreEmptyService, test.input) + } + } +} diff --git a/plugin/kubernetes/setup_ttl_test.go b/plugin/kubernetes/setup_ttl_test.go new file mode 100644 index 0000000..16b9b4a --- /dev/null +++ b/plugin/kubernetes/setup_ttl_test.go @@ -0,0 +1,45 @@ +package kubernetes + +import ( + "testing" + + "github.com/coredns/caddy" +) + +func TestKubernetesParseTTL(t *testing.T) { + tests := []struct { + input string // Corefile data as string + expectedTTL uint32 // expected count of defined zones. + shouldErr bool + }{ + {`kubernetes cluster.local { + ttl 56 + }`, 56, false}, + {`kubernetes cluster.local`, defaultTTL, false}, + {`kubernetes cluster.local { + ttl -1 + }`, 0, true}, + {`kubernetes cluster.local { + ttl 3601 + }`, 0, true}, + } + + for i, tc := range tests { + c := caddy.NewTestController("dns", tc.input) + k, err := kubernetesParse(c) + if err != nil && !tc.shouldErr { + t.Fatalf("Test %d: Expected no error, got %q", i, err) + } + if err == nil && tc.shouldErr { + t.Fatalf("Test %d: Expected error, got none", i) + } + if err != nil && tc.shouldErr { + // input should error + continue + } + + if k.ttl != tc.expectedTTL { + t.Errorf("Test %d: Expected TTl to be %d, got %d", i, tc.expectedTTL, k.ttl) + } + } +} diff --git a/plugin/kubernetes/xfr.go b/plugin/kubernetes/xfr.go new file mode 100644 index 0000000..4d941a9 --- /dev/null +++ b/plugin/kubernetes/xfr.go @@ -0,0 +1,195 @@ +package kubernetes + +import ( + "context" + "math" + "net" + "sort" + "strings" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/etcd/msg" + "github.com/coredns/coredns/plugin/transfer" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" + api "k8s.io/api/core/v1" +) + +// Transfer implements the transfer.Transfer interface. 
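+// The transfer is streamed as a sequence of RR sets: an opening SOA, the NS
+// records and their addresses, then the records for every service in an
+// exposed namespace (A/AAAA and SRV for ClusterIP-type services, per-endpoint
+// records for headless services, a CNAME for ExternalName services), and a
+// closing SOA. If the requested serial already matches the zone's SOA serial,
+// only the SOA is sent. Zones not served by this plugin return
+// transfer.ErrNotAuthoritative.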
+func (k *Kubernetes) Transfer(zone string, serial uint32) (<-chan []dns.RR, error) { + match := plugin.Zones(k.Zones).Matches(zone) + if match == "" { + return nil, transfer.ErrNotAuthoritative + } + // state is not used here, hence the empty request.Request{] + soa, err := plugin.SOA(context.TODO(), k, zone, request.Request{}, plugin.Options{}) + if err != nil { + return nil, transfer.ErrNotAuthoritative + } + + ch := make(chan []dns.RR) + + zonePath := msg.Path(zone, "coredns") + serviceList := k.APIConn.ServiceList() + + go func() { + // ixfr fallback + if serial != 0 && soa[0].(*dns.SOA).Serial == serial { + ch <- soa + close(ch) + return + } + ch <- soa + + nsAddrs := k.nsAddrs(false, false, zone) + nsHosts := make(map[string]struct{}) + for _, nsAddr := range nsAddrs { + nsHost := nsAddr.Header().Name + if _, ok := nsHosts[nsHost]; !ok { + nsHosts[nsHost] = struct{}{} + ch <- []dns.RR{&dns.NS{Hdr: dns.RR_Header{Name: zone, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: k.ttl}, Ns: nsHost}} + } + ch <- nsAddrs + } + + sort.Slice(serviceList, func(i, j int) bool { + return serviceList[i].Name < serviceList[j].Name + }) + + for _, svc := range serviceList { + if !k.namespaceExposed(svc.Namespace) { + continue + } + svcBase := []string{zonePath, Svc, svc.Namespace, svc.Name} + switch svc.Type { + case api.ServiceTypeClusterIP, api.ServiceTypeNodePort, api.ServiceTypeLoadBalancer: + clusterIP := net.ParseIP(svc.ClusterIPs[0]) + if clusterIP != nil { + var host string + for _, ip := range svc.ClusterIPs { + s := msg.Service{Host: ip, TTL: k.ttl} + s.Key = strings.Join(svcBase, "/") + + // Change host from IP to Name for SRV records + host = emitAddressRecord(ch, s) + } + + for _, p := range svc.Ports { + s := msg.Service{Host: host, Port: int(p.Port), TTL: k.ttl} + s.Key = strings.Join(svcBase, "/") + + // Need to generate this to handle use cases for peer-finder + // ref: https://github.com/coredns/coredns/pull/823 + ch <- []dns.RR{s.NewSRV(msg.Domain(s.Key), 100)} + + // As per spec unnamed ports do not have a srv record + // https://github.com/kubernetes/dns/blob/master/docs/specification.md#232---srv-records + if p.Name == "" { + continue + } + + s.Key = strings.Join(append(svcBase, strings.ToLower("_"+string(p.Protocol)), strings.ToLower("_"+p.Name)), "/") + + ch <- []dns.RR{s.NewSRV(msg.Domain(s.Key), 100)} + } + + // Skip endpoint discovery if clusterIP is defined + continue + } + + endpointsList := k.APIConn.EpIndex(svc.Name + "." 
+ svc.Namespace) + + for _, ep := range endpointsList { + for _, eps := range ep.Subsets { + srvWeight := calcSRVWeight(len(eps.Addresses)) + for _, addr := range eps.Addresses { + s := msg.Service{Host: addr.IP, TTL: k.ttl} + s.Key = strings.Join(svcBase, "/") + // We don't need to change the msg.Service host from IP to Name yet + // so disregard the return value here + emitAddressRecord(ch, s) + + s.Key = strings.Join(append(svcBase, endpointHostname(addr, k.endpointNameMode)), "/") + // Change host from IP to Name for SRV records + host := emitAddressRecord(ch, s) + s.Host = host + + for _, p := range eps.Ports { + // As per spec unnamed ports do not have a srv record + // https://github.com/kubernetes/dns/blob/master/docs/specification.md#232---srv-records + if p.Name == "" { + continue + } + + s.Port = int(p.Port) + + s.Key = strings.Join(append(svcBase, strings.ToLower("_"+p.Protocol), strings.ToLower("_"+p.Name)), "/") + ch <- []dns.RR{s.NewSRV(msg.Domain(s.Key), srvWeight)} + } + } + } + } + + case api.ServiceTypeExternalName: + + s := msg.Service{Key: strings.Join(svcBase, "/"), Host: svc.ExternalName, TTL: k.ttl} + if t, _ := s.HostType(); t == dns.TypeCNAME { + ch <- []dns.RR{s.NewCNAME(msg.Domain(s.Key), s.Host)} + } + } + } + ch <- soa + close(ch) + }() + return ch, nil +} + +// emitAddressRecord generates a new A or AAAA record based on the msg.Service and writes it to a channel. +// emitAddressRecord returns the host name from the generated record. +func emitAddressRecord(c chan<- []dns.RR, s msg.Service) string { + ip := net.ParseIP(s.Host) + dnsType, _ := s.HostType() + switch dnsType { + case dns.TypeA: + r := s.NewA(msg.Domain(s.Key), ip) + c <- []dns.RR{r} + return r.Hdr.Name + case dns.TypeAAAA: + r := s.NewAAAA(msg.Domain(s.Key), ip) + c <- []dns.RR{r} + return r.Hdr.Name + } + + return "" +} + +// calcSrvWeight borrows the logic implemented in plugin.SRV for dynamically +// calculating the srv weight and priority +func calcSRVWeight(numservices int) uint16 { + var services []msg.Service + + for i := 0; i < numservices; i++ { + services = append(services, msg.Service{}) + } + + w := make(map[int]int) + for _, serv := range services { + weight := 100 + if serv.Weight != 0 { + weight = serv.Weight + } + if _, ok := w[serv.Priority]; !ok { + w[serv.Priority] = weight + continue + } + w[serv.Priority] += weight + } + weight := uint16(math.Floor((100.0 / float64(w[0])) * 100)) + // weight should be at least 1 + if weight == 0 { + weight = 1 + } + + return weight +} diff --git a/plugin/kubernetes/xfr_test.go b/plugin/kubernetes/xfr_test.go new file mode 100644 index 0000000..61e5d0a --- /dev/null +++ b/plugin/kubernetes/xfr_test.go @@ -0,0 +1,156 @@ +package kubernetes + +import ( + "net" + "strings" + "testing" + + "github.com/coredns/coredns/plugin/transfer" + + "github.com/miekg/dns" +) + +func TestKubernetesTransferNonAuthZone(t *testing.T) { + k := New([]string{"cluster.local."}) + k.APIConn = &APIConnServeTest{} + k.Namespaces = map[string]struct{}{"testns": {}, "kube-system": {}} + k.localIPs = []net.IP{net.ParseIP("10.0.0.10")} + + dnsmsg := &dns.Msg{} + dnsmsg.SetAxfr("example.com") + + _, err := k.Transfer("example.com", 0) + if err != transfer.ErrNotAuthoritative { + t.Error(err) + } +} + +func TestKubernetesAXFR(t *testing.T) { + k := New([]string{"cluster.local."}) + k.APIConn = &APIConnServeTest{} + k.Namespaces = map[string]struct{}{"testns": {}, "kube-system": {}} + k.localIPs = []net.IP{net.ParseIP("10.0.0.10")} + + dnsmsg := &dns.Msg{} + 
dnsmsg.SetAxfr(k.Zones[0]) + + ch, err := k.Transfer(k.Zones[0], 0) + if err != nil { + t.Error(err) + } + validateAXFR(t, ch) +} + +func TestKubernetesIXFRFallback(t *testing.T) { + k := New([]string{"cluster.local."}) + k.APIConn = &APIConnServeTest{} + k.Namespaces = map[string]struct{}{"testns": {}, "kube-system": {}} + k.localIPs = []net.IP{net.ParseIP("10.0.0.10")} + + dnsmsg := &dns.Msg{} + dnsmsg.SetAxfr(k.Zones[0]) + + ch, err := k.Transfer(k.Zones[0], 1) + if err != nil { + t.Error(err) + } + validateAXFR(t, ch) +} + +func TestKubernetesIXFRCurrent(t *testing.T) { + k := New([]string{"cluster.local."}) + k.APIConn = &APIConnServeTest{} + k.Namespaces = map[string]struct{}{"testns": {}, "kube-system": {}} + k.localIPs = []net.IP{net.ParseIP("10.0.0.10")} + + dnsmsg := &dns.Msg{} + dnsmsg.SetAxfr(k.Zones[0]) + + ch, err := k.Transfer(k.Zones[0], 3) + if err != nil { + t.Error(err) + } + + var gotRRs []dns.RR + for rrs := range ch { + gotRRs = append(gotRRs, rrs...) + } + + // ensure only one record is returned + if len(gotRRs) > 1 { + t.Errorf("Expected only one answer, got %d", len(gotRRs)) + } + + // Ensure first record is a SOA + if gotRRs[0].Header().Rrtype != dns.TypeSOA { + t.Error("Invalid transfer response, does not start with SOA record") + } +} + +func validateAXFR(t *testing.T, ch <-chan []dns.RR) { + xfr := []dns.RR{} + for rrs := range ch { + xfr = append(xfr, rrs...) + } + if xfr[0].Header().Rrtype != dns.TypeSOA { + t.Error("Invalid transfer response, does not start with SOA record") + } + + zp := dns.NewZoneParser(strings.NewReader(expectedZone), "", "") + i := 0 + for rr, ok := zp.Next(); ok; rr, ok = zp.Next() { + if !dns.IsDuplicate(rr, xfr[i]) { + t.Fatalf("Record %d, expected\n%v\n, got\n%v", i, rr, xfr[i]) + } + i++ + } + + if err := zp.Err(); err != nil { + t.Fatal(err) + } +} + +const expectedZone = ` +cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 3 7200 1800 86400 5 +cluster.local. 5 IN NS ns.dns.cluster.local. +ns.dns.cluster.local. 5 IN A 10.0.0.10 +external.testns.svc.cluster.local. 5 IN CNAME ext.interwebs.test. +external-to-service.testns.svc.cluster.local. 5 IN CNAME svc1.testns.svc.cluster.local. +hdls1.testns.svc.cluster.local. 5 IN A 172.0.0.2 +172-0-0-2.hdls1.testns.svc.cluster.local. 5 IN A 172.0.0.2 +_http._tcp.hdls1.testns.svc.cluster.local. 5 IN SRV 0 16 80 172-0-0-2.hdls1.testns.svc.cluster.local. +hdls1.testns.svc.cluster.local. 5 IN A 172.0.0.3 +172-0-0-3.hdls1.testns.svc.cluster.local. 5 IN A 172.0.0.3 +_http._tcp.hdls1.testns.svc.cluster.local. 5 IN SRV 0 16 80 172-0-0-3.hdls1.testns.svc.cluster.local. +hdls1.testns.svc.cluster.local. 5 IN A 172.0.0.4 +dup-name.hdls1.testns.svc.cluster.local. 5 IN A 172.0.0.4 +_http._tcp.hdls1.testns.svc.cluster.local. 5 IN SRV 0 16 80 dup-name.hdls1.testns.svc.cluster.local. +hdls1.testns.svc.cluster.local. 5 IN A 172.0.0.5 +dup-name.hdls1.testns.svc.cluster.local. 5 IN A 172.0.0.5 +_http._tcp.hdls1.testns.svc.cluster.local. 5 IN SRV 0 16 80 dup-name.hdls1.testns.svc.cluster.local. +hdls1.testns.svc.cluster.local. 5 IN AAAA 5678:abcd::1 +5678-abcd--1.hdls1.testns.svc.cluster.local. 5 IN AAAA 5678:abcd::1 +_http._tcp.hdls1.testns.svc.cluster.local. 5 IN SRV 0 16 80 5678-abcd--1.hdls1.testns.svc.cluster.local. +hdls1.testns.svc.cluster.local. 5 IN AAAA 5678:abcd::2 +5678-abcd--2.hdls1.testns.svc.cluster.local. 5 IN AAAA 5678:abcd::2 +_http._tcp.hdls1.testns.svc.cluster.local. 5 IN SRV 0 16 80 5678-abcd--2.hdls1.testns.svc.cluster.local. +hdlsprtls.testns.svc.cluster.local. 
5 IN A 172.0.0.20 +172-0-0-20.hdlsprtls.testns.svc.cluster.local. 5 IN A 172.0.0.20 +kubedns.kube-system.svc.cluster.local. 5 IN A 10.0.0.10 +kubedns.kube-system.svc.cluster.local. 5 IN SRV 0 100 53 kubedns.kube-system.svc.cluster.local. +_dns._udp.kubedns.kube-system.svc.cluster.local. 5 IN SRV 0 100 53 kubedns.kube-system.svc.cluster.local. +svc-dual-stack.testns.svc.cluster.local. 5 IN A 10.0.0.3 +svc-dual-stack.testns.svc.cluster.local. 5 IN AAAA 10::3 +svc-dual-stack.testns.svc.cluster.local. 5 IN SRV 0 100 80 svc-dual-stack.testns.svc.cluster.local. +_http._tcp.svc-dual-stack.testns.svc.cluster.local. 5 IN SRV 0 100 80 svc-dual-stack.testns.svc.cluster.local. +svc1.testns.svc.cluster.local. 5 IN A 10.0.0.1 +svc1.testns.svc.cluster.local. 5 IN SRV 0 100 80 svc1.testns.svc.cluster.local. +_http._tcp.svc1.testns.svc.cluster.local. 5 IN SRV 0 100 80 svc1.testns.svc.cluster.local. +svc6.testns.svc.cluster.local. 5 IN AAAA 1234:abcd::1 +svc6.testns.svc.cluster.local. 5 IN SRV 0 100 80 svc6.testns.svc.cluster.local. +_http._tcp.svc6.testns.svc.cluster.local. 5 IN SRV 0 100 80 svc6.testns.svc.cluster.local. +svcempty.testns.svc.cluster.local. 5 IN A 10.0.0.1 +svcempty.testns.svc.cluster.local. 5 IN SRV 0 100 80 svcempty.testns.svc.cluster.local. +_http._tcp.svcempty.testns.svc.cluster.local. 5 IN SRV 0 100 80 svcempty.testns.svc.cluster.local. +cluster.local. 5 IN SOA ns.dns.cluster.local. hostmaster.cluster.local. 3 7200 1800 86400 5 +` diff --git a/plugin/loadbalance/README.md b/plugin/loadbalance/README.md new file mode 100644 index 0000000..fe29b19 --- /dev/null +++ b/plugin/loadbalance/README.md @@ -0,0 +1,90 @@ +# loadbalance + +## Name + +*loadbalance* - randomizes the order of A, AAAA and MX records. + +## Description + +The *loadbalance* will act as a round-robin DNS load balancer by randomizing the order of A, AAAA, +and MX records in the answer. + +See [Wikipedia](https://en.wikipedia.org/wiki/Round-robin_DNS) about the pros and cons of this +setup. It will take care to sort any CNAMEs before any address records, because some stub resolver +implementations (like glibc) are particular about that. + +## Syntax + +~~~ +loadbalance [round_robin | weighted WEIGHTFILE] { + reload DURATION +} +~~~ +* `round_robin` policy randomizes the order of A, AAAA, and MX records applying a uniform probability distribution. This is the default load balancing policy. + +* `weighted` policy assigns weight values to IPs to control the relative likelihood of particular IPs to be returned as the first +(top) A/AAAA record in the answer. Note that it does not shuffle all the records in the answer, it is only concerned about the first A/AAAA record +returned in the answer. + + * **WEIGHTFILE** is the file containing the weight values assigned to IPs for various domain names. If the path is relative, the path from the **root** plugin will be prepended to it. The format is explained below in the *Weightfile* section. + + * **DURATION** interval to reload `WEIGHTFILE` and update weight assignments if there are changes in the file. The default value is `30s`. A value of `0s` means to not scan for changes and reload. + + +## Weightfile + +The generic weight file syntax: + +~~~ +# Comment lines are ignored + +domain-name1 +ip11 weight11 +ip12 weight12 +ip13 weight13 + +domain-name2 +ip21 weight21 +ip22 weight22 +# ... etc. +~~~ + +where `ipXY` is an IP address for `domain-nameX` and `weightXY` is the weight value associated with that IP. The weight values are in the range of [1,255]. 
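+
+For readers who want to consume this file format outside of CoreDNS, the sketch
+below shows one way to parse it. It is only an illustration: the `parseWeightFile`
+function and `weightEntry` type are made-up names and are not part of the plugin's
+API (the plugin's own parser lives in `weighted.go`).
+
+~~~ go
+package main
+
+import (
+	"bufio"
+	"fmt"
+	"io"
+	"net"
+	"strconv"
+	"strings"
+)
+
+// weightEntry pairs an address with its weight.
+type weightEntry struct {
+	IP     net.IP
+	Weight uint8
+}
+
+// parseWeightFile reads the "domain name, then IP/weight pairs" format shown above.
+func parseWeightFile(r io.Reader) (map[string][]weightEntry, error) {
+	domains := map[string][]weightEntry{}
+	current := ""
+	sc := bufio.NewScanner(r)
+	for sc.Scan() {
+		line := strings.TrimSpace(sc.Text())
+		if line == "" || strings.HasPrefix(line, "#") {
+			continue // blank lines and comments are ignored
+		}
+		fields := strings.Fields(line)
+		switch len(fields) {
+		case 1: // a bare name starts a new domain section
+			current = strings.ToLower(fields[0])
+			if !strings.HasSuffix(current, ".") {
+				current += "." // treat names as fully qualified
+			}
+		case 2: // an "ip weight" pair for the current domain
+			if current == "" {
+				return nil, fmt.Errorf("weight line %q appears before any domain name", line)
+			}
+			ip := net.ParseIP(fields[0])
+			if ip == nil {
+				return nil, fmt.Errorf("invalid IP address %q", fields[0])
+			}
+			w, err := strconv.ParseUint(fields[1], 10, 8)
+			if err != nil || w == 0 {
+				return nil, fmt.Errorf("weight %q is not in the range [1,255]", fields[1])
+			}
+			domains[current] = append(domains[current], weightEntry{IP: ip, Weight: uint8(w)})
+		default:
+			return nil, fmt.Errorf("cannot parse weight file line %q", line)
+		}
+	}
+	return domains, sc.Err()
+}
+
+func main() {
+	const sample = "www.example.com\n100.64.1.1 3\n100.64.1.2 1\n100.64.1.3 2\n"
+	d, err := parseWeightFile(strings.NewReader(sample))
+	if err != nil {
+		panic(err)
+	}
+	fmt.Println(d["www.example.com."]) // the three weighted addresses
+}
+~~~
+
+In this sketch, weights of `0` or above `255` and lines that are neither a bare name nor an address/weight pair are rejected, mirroring the [1,255] rule described above.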
+ +The `weighted` policy selects one of the address record in the result list and moves it to the top (first) position in the list. The random selection takes into account the weight values assigned to the addresses in the weight file. If an address in the result list is associated with no weight value in the weight file then the default weight value "1" is assumed for it when the selection is performed. + + +## Examples + +Load balance replies coming back from Google Public DNS: + +~~~ corefile +. { + loadbalance round_robin + forward . 8.8.8.8 8.8.4.4 +} +~~~ + +Use the `weighted` strategy to load balance replies supplied by the **file** plugin. We assign weight vales `3`, `1` and `2` to the IPs `100.64.1.1`, `100.64.1.2` and `100.64.1.3`, respectively. These IPs are addresses in A records for the domain name `www.example.com` defined in the `./db.example.com` zone file. The ratio between the number of answers in which `100.64.1.1`, `100.64.1.2` or `100.64.1.3` is in the top (first) A record should converge to `3 : 1 : 2`. (E.g. there should be twice as many answers with `100.64.1.3` in the top A record than with `100.64.1.2`). +Corefile: + +~~~ corefile +example.com { + file ./db.example.com { + reload 10s + } + loadbalance weighted ./db.example.com.weights { + reload 10s + } +} +~~~ + +weight file `./db.example.com.weights`: + +~~~ +www.example.com +100.64.1.1 3 +100.64.1.2 1 +100.64.1.3 2 +~~~ + diff --git a/plugin/loadbalance/handler.go b/plugin/loadbalance/handler.go new file mode 100644 index 0000000..8b84e1c --- /dev/null +++ b/plugin/loadbalance/handler.go @@ -0,0 +1,25 @@ +// Package loadbalance is a plugin for rewriting responses to do "load balancing" +package loadbalance + +import ( + "context" + + "github.com/coredns/coredns/plugin" + + "github.com/miekg/dns" +) + +// RoundRobin is a plugin to rewrite responses for "load balancing". +type LoadBalance struct { + Next plugin.Handler + shuffle func(*dns.Msg) *dns.Msg +} + +// ServeDNS implements the plugin.Handler interface. +func (lb LoadBalance) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + rw := &LoadBalanceResponseWriter{ResponseWriter: w, shuffle: lb.shuffle} + return plugin.NextOrFailure(lb.Name(), lb.Next, ctx, rw, r) +} + +// Name implements the Handler interface. +func (lb LoadBalance) Name() string { return "loadbalance" } diff --git a/plugin/loadbalance/loadbalance.go b/plugin/loadbalance/loadbalance.go new file mode 100644 index 0000000..f2a1cae --- /dev/null +++ b/plugin/loadbalance/loadbalance.go @@ -0,0 +1,91 @@ +// Package loadbalance shuffles A, AAAA and MX records. +package loadbalance + +import ( + "github.com/miekg/dns" +) + +const ( + ramdomShufflePolicy = "round_robin" + weightedRoundRobinPolicy = "weighted" +) + +// LoadBalanceResponseWriter is a response writer that shuffles A, AAAA and MX records. +type LoadBalanceResponseWriter struct { + dns.ResponseWriter + shuffle func(*dns.Msg) *dns.Msg +} + +// WriteMsg implements the dns.ResponseWriter interface. 
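+// Replies with a non-success rcode and zone transfer answers (AXFR/IXFR) are
+// written through unchanged; everything else is passed to the configured
+// shuffle function before being written.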
+func (r *LoadBalanceResponseWriter) WriteMsg(res *dns.Msg) error { + if res.Rcode != dns.RcodeSuccess { + return r.ResponseWriter.WriteMsg(res) + } + + if res.Question[0].Qtype == dns.TypeAXFR || res.Question[0].Qtype == dns.TypeIXFR { + return r.ResponseWriter.WriteMsg(res) + } + + return r.ResponseWriter.WriteMsg(r.shuffle(res)) +} + +func randomShuffle(res *dns.Msg) *dns.Msg { + res.Answer = roundRobin(res.Answer) + res.Ns = roundRobin(res.Ns) + res.Extra = roundRobin(res.Extra) + return res +} + +func roundRobin(in []dns.RR) []dns.RR { + cname := []dns.RR{} + address := []dns.RR{} + mx := []dns.RR{} + rest := []dns.RR{} + for _, r := range in { + switch r.Header().Rrtype { + case dns.TypeCNAME: + cname = append(cname, r) + case dns.TypeA, dns.TypeAAAA: + address = append(address, r) + case dns.TypeMX: + mx = append(mx, r) + default: + rest = append(rest, r) + } + } + + roundRobinShuffle(address) + roundRobinShuffle(mx) + + out := append(cname, rest...) + out = append(out, address...) + out = append(out, mx...) + return out +} + +func roundRobinShuffle(records []dns.RR) { + switch l := len(records); l { + case 0, 1: + break + case 2: + if dns.Id()%2 == 0 { + records[0], records[1] = records[1], records[0] + } + default: + for j := 0; j < l; j++ { + p := j + (int(dns.Id()) % (l - j)) + if j == p { + continue + } + records[j], records[p] = records[p], records[j] + } + } +} + +// Write implements the dns.ResponseWriter interface. +func (r *LoadBalanceResponseWriter) Write(buf []byte) (int, error) { + // Should we pack and unpack here to fiddle with the packet... Not likely. + log.Warning("LoadBalance called with Write: not shuffling records") + n, err := r.ResponseWriter.Write(buf) + return n, err +} diff --git a/plugin/loadbalance/loadbalance_test.go b/plugin/loadbalance/loadbalance_test.go new file mode 100644 index 0000000..c46d968 --- /dev/null +++ b/plugin/loadbalance/loadbalance_test.go @@ -0,0 +1,203 @@ +package loadbalance + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestLoadBalanceRandom(t *testing.T) { + rm := LoadBalance{Next: handler(), shuffle: randomShuffle} + + // the first X records must be cnames after this test + tests := []struct { + answer []dns.RR + extra []dns.RR + cnameAnswer int + cnameExtra int + addressAnswer int + addressExtra int + mxAnswer int + mxExtra int + }{ + { + answer: []dns.RR{ + test.CNAME("cname1.region2.skydns.test. 300 IN CNAME cname2.region2.skydns.test."), + test.CNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."), + test.CNAME("cname5.region2.skydns.test. 300 IN CNAME cname6.region2.skydns.test."), + test.CNAME("cname6.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."), + test.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), + test.MX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."), + test.MX("mx.region2.skydns.test. 300 IN MX 2 mx2.region2.skydns.test."), + test.MX("mx.region2.skydns.test. 300 IN MX 3 mx3.region2.skydns.test."), + }, + cnameAnswer: 4, + addressAnswer: 1, + mxAnswer: 3, + }, + { + answer: []dns.RR{ + test.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), + test.MX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."), + test.CNAME("cname.region2.skydns.test. 
300 IN CNAME endpoint.region2.skydns.test."), + }, + cnameAnswer: 1, + addressAnswer: 1, + mxAnswer: 1, + }, + { + answer: []dns.RR{ + test.MX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."), + test.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), + test.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.2"), + test.MX("mx.region2.skydns.test. 300 IN MX 1 mx2.region2.skydns.test."), + test.CNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."), + test.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"), + test.MX("mx.region2.skydns.test. 300 IN MX 1 mx3.region2.skydns.test."), + }, + extra: []dns.RR{ + test.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), + test.AAAA("endpoint.region2.skydns.test. 300 IN AAAA ::1"), + test.MX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."), + test.CNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."), + test.MX("mx.region2.skydns.test. 300 IN MX 1 mx2.region2.skydns.test."), + test.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"), + test.AAAA("endpoint.region2.skydns.test. 300 IN AAAA ::2"), + test.MX("mx.region2.skydns.test. 300 IN MX 1 mx3.region2.skydns.test."), + }, + cnameAnswer: 1, + cnameExtra: 1, + addressAnswer: 3, + addressExtra: 4, + mxAnswer: 3, + mxExtra: 3, + }, + } + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + + for i, test := range tests { + req := new(dns.Msg) + req.SetQuestion("region2.skydns.test.", dns.TypeSRV) + req.Answer = test.answer + req.Extra = test.extra + + _, err := rm.ServeDNS(context.TODO(), rec, req) + if err != nil { + t.Errorf("Test %d: Expected no error, but got %s", i, err) + continue + } + + cname, address, mx, sorted := countRecords(rec.Msg.Answer) + if !sorted { + t.Errorf("Test %d: Expected CNAMEs, then AAAAs, then MX in Answer, but got mixed", i) + } + if cname != test.cnameAnswer { + t.Errorf("Test %d: Expected %d CNAMEs in Answer, but got %d", i, test.cnameAnswer, cname) + } + if address != test.addressAnswer { + t.Errorf("Test %d: Expected %d A/AAAAs in Answer, but got %d", i, test.addressAnswer, address) + } + if mx != test.mxAnswer { + t.Errorf("Test %d: Expected %d MXs in Answer, but got %d", i, test.mxAnswer, mx) + } + + cname, address, mx, sorted = countRecords(rec.Msg.Extra) + if !sorted { + t.Errorf("Test %d: Expected CNAMEs, then AAAAs, then MX in Extra, but got mixed", i) + } + if cname != test.cnameExtra { + t.Errorf("Test %d: Expected %d CNAMEs in Extra, but got %d", i, test.cnameAnswer, cname) + } + if address != test.addressExtra { + t.Errorf("Test %d: Expected %d A/AAAAs in Extra, but got %d", i, test.addressAnswer, address) + } + if mx != test.mxExtra { + t.Errorf("Test %d: Expected %d MXs in Extra, but got %d", i, test.mxAnswer, mx) + } + } +} + +func TestLoadBalanceXFR(t *testing.T) { + rm := LoadBalance{Next: handler()} + + answer := []dns.RR{ + test.SOA("skydns.test. 30 IN SOA ns.dns.skydns.test. hostmaster.skydns.test. 1542756695 7200 1800 86400 30"), + test.MX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."), + test.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), + test.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.2"), + test.MX("mx.region2.skydns.test. 300 IN MX 1 mx2.region2.skydns.test."), + test.CNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."), + test.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"), + test.MX("mx.region2.skydns.test. 300 IN MX 1 mx3.region2.skydns.test."), + test.SOA("skydns.test. 
30 IN SOA ns.dns.skydns.test. hostmaster.skydns.test. 1542756695 7200 1800 86400 30"), + } + + for _, xfrtype := range []uint16{dns.TypeIXFR, dns.TypeAXFR} { + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + req := new(dns.Msg) + req.SetQuestion("skydns.test.", xfrtype) + req.Answer = answer + _, err := rm.ServeDNS(context.TODO(), rec, req) + if err != nil { + t.Errorf("Expected no error, but got %s for %s", err, dns.TypeToString[xfrtype]) + continue + } + + if rec.Msg.Answer[0].Header().Rrtype != dns.TypeSOA { + t.Errorf("Expected SOA record for first answer for %s", dns.TypeToString[xfrtype]) + } + + if rec.Msg.Answer[len(rec.Msg.Answer)-1].Header().Rrtype != dns.TypeSOA { + t.Errorf("Expected SOA record for last answer for %s", dns.TypeToString[xfrtype]) + } + } +} + +func countRecords(result []dns.RR) (cname int, address int, mx int, sorted bool) { + const ( + Start = iota + CNAMERecords + ARecords + MXRecords + Any + ) + + // The order of the records is used to determine if the round-robin actually did anything. + sorted = true + cname = 0 + address = 0 + mx = 0 + state := Start + for _, r := range result { + switch r.Header().Rrtype { + case dns.TypeCNAME: + sorted = sorted && state <= CNAMERecords + state = CNAMERecords + cname++ + case dns.TypeA, dns.TypeAAAA: + sorted = sorted && state <= ARecords + state = ARecords + address++ + case dns.TypeMX: + sorted = sorted && state <= MXRecords + state = MXRecords + mx++ + default: + state = Any + } + } + return +} + +func handler() plugin.Handler { + return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + w.WriteMsg(r) + return dns.RcodeSuccess, nil + }) +} diff --git a/plugin/loadbalance/log_test.go b/plugin/loadbalance/log_test.go new file mode 100644 index 0000000..e4dbd6d --- /dev/null +++ b/plugin/loadbalance/log_test.go @@ -0,0 +1,5 @@ +package loadbalance + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/loadbalance/setup.go b/plugin/loadbalance/setup.go new file mode 100644 index 0000000..5706aeb --- /dev/null +++ b/plugin/loadbalance/setup.go @@ -0,0 +1,103 @@ +package loadbalance + +import ( + "errors" + "fmt" + "path/filepath" + "time" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + clog "github.com/coredns/coredns/plugin/pkg/log" + + "github.com/miekg/dns" +) + +var log = clog.NewWithPlugin("loadbalance") +var errOpen = errors.New("Weight file open error") + +func init() { plugin.Register("loadbalance", setup) } + +type lbFuncs struct { + shuffleFunc func(*dns.Msg) *dns.Msg + onStartUpFunc func() error + onShutdownFunc func() error + weighted *weightedRR // used in unit tests only +} + +func setup(c *caddy.Controller) error { + //shuffleFunc, startUpFunc, shutdownFunc, err := parse(c) + lb, err := parse(c) + if err != nil { + return plugin.Error("loadbalance", err) + } + if lb.onStartUpFunc != nil { + c.OnStartup(lb.onStartUpFunc) + } + if lb.onShutdownFunc != nil { + c.OnShutdown(lb.onShutdownFunc) + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + return LoadBalance{Next: next, shuffle: lb.shuffleFunc} + }) + + return nil +} + +// func parse(c *caddy.Controller) (string, *weightedRR, error) { +func parse(c *caddy.Controller) (*lbFuncs, error) { + config := dnsserver.GetConfig(c) + + for c.Next() { + args := c.RemainingArgs() + if len(args) == 0 { + return &lbFuncs{shuffleFunc: randomShuffle}, nil + } + switch 
args[0] { + case ramdomShufflePolicy: + if len(args) > 1 { + return nil, c.Errf("unknown property for %s", args[0]) + } + return &lbFuncs{shuffleFunc: randomShuffle}, nil + case weightedRoundRobinPolicy: + if len(args) < 2 { + return nil, c.Err("missing weight file argument") + } + + if len(args) > 2 { + return nil, c.Err("unexpected argument(s)") + } + + weightFileName := args[1] + if !filepath.IsAbs(weightFileName) && config.Root != "" { + weightFileName = filepath.Join(config.Root, weightFileName) + } + reload := 30 * time.Second // default reload period + for c.NextBlock() { + switch c.Val() { + case "reload": + t := c.RemainingArgs() + if len(t) < 1 { + return nil, c.Err("reload duration value is missing") + } + if len(t) > 1 { + return nil, c.Err("unexpected argument") + } + var err error + reload, err = time.ParseDuration(t[0]) + if err != nil { + return nil, c.Errf("invalid reload duration '%s'", t[0]) + } + default: + return nil, c.Errf("unknown property '%s'", c.Val()) + } + } + return createWeightedFuncs(weightFileName, reload), nil + default: + return nil, fmt.Errorf("unknown policy: %s", args[0]) + } + } + return nil, c.ArgErr() +} diff --git a/plugin/loadbalance/setup_test.go b/plugin/loadbalance/setup_test.go new file mode 100644 index 0000000..4e3c99c --- /dev/null +++ b/plugin/loadbalance/setup_test.go @@ -0,0 +1,100 @@ +package loadbalance + +import ( + "strings" + "testing" + + "github.com/coredns/caddy" +) + +// weighted round robin specific test data +var testWeighted = []struct { + expectedWeightFile string + expectedWeightReload string +}{ + {"wfile", "30s"}, + {"wf", "10s"}, + {"wf", "0s"}, +} + +func TestSetup(t *testing.T) { + tests := []struct { + input string + shouldErr bool + expectedPolicy string + expectedErrContent string // substring from the expected error. Empty for positive cases. + weightedDataIndex int // weighted round robin specific data index + }{ + // positive + {`loadbalance`, false, "round_robin", "", -1}, + {`loadbalance round_robin`, false, "round_robin", "", -1}, + {`loadbalance weighted wfile`, false, "weighted", "", 0}, + {`loadbalance weighted wf { + reload 10s + } `, false, "weighted", "", 1}, + {`loadbalance weighted wf { + reload 0s + } `, false, "weighted", "", 2}, + // negative + {`loadbalance fleeb`, true, "", "unknown policy", -1}, + {`loadbalance round_robin a`, true, "", "unknown property", -1}, + {`loadbalance weighted`, true, "", "missing weight file argument", -1}, + {`loadbalance weighted a b`, true, "", "unexpected argument", -1}, + {`loadbalance weighted wfile { + susu + } `, true, "", "unknown property", -1}, + {`loadbalance weighted wfile { + reload a + } `, true, "", "invalid reload duration", -1}, + {`loadbalance weighted wfile { + reload 30s a + } `, true, "", "unexpected argument", -1}, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + lb, err := parse(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected error but found %s for input %s", i, err, test.input) + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: Expected no error but found one for input %s. 
Error was: %v", + i, test.input, err) + } + + if !strings.Contains(err.Error(), test.expectedErrContent) { + t.Errorf("Test %d: Expected error to contain: %v, found error: %v, input: %s", + i, test.expectedErrContent, err, test.input) + } + continue + } + + if lb == nil { + t.Errorf("Test %d: Expected valid loadbalance funcs but got nil for input %s", + i, test.input) + continue + } + policy := ramdomShufflePolicy + if lb.weighted != nil { + policy = weightedRoundRobinPolicy + } + if policy != test.expectedPolicy { + t.Errorf("Test %d: Expected policy %s but got %s for input %s", i, + test.expectedPolicy, policy, test.input) + } + if policy == weightedRoundRobinPolicy && test.weightedDataIndex >= 0 { + i := test.weightedDataIndex + if testWeighted[i].expectedWeightFile != lb.weighted.fileName { + t.Errorf("Test %d: Expected weight file name %s but got %s for input %s", + i, testWeighted[i].expectedWeightFile, lb.weighted.fileName, test.input) + } + if testWeighted[i].expectedWeightReload != lb.weighted.reload.String() { + t.Errorf("Test %d: Expected weight reload duration %s but got %s for input %s", + i, testWeighted[i].expectedWeightReload, lb.weighted.reload, test.input) + } + } + } +} diff --git a/plugin/loadbalance/weighted.go b/plugin/loadbalance/weighted.go new file mode 100644 index 0000000..39b1780 --- /dev/null +++ b/plugin/loadbalance/weighted.go @@ -0,0 +1,329 @@ +package loadbalance + +import ( + "bufio" + "bytes" + "crypto/md5" + "errors" + "fmt" + "io" + "math/rand" + "net" + "os" + "path/filepath" + "sort" + "strconv" + "strings" + "sync" + "time" + + "github.com/coredns/coredns/plugin" + + "github.com/miekg/dns" +) + +type ( + // "weighted-round-robin" policy specific data + weightedRR struct { + fileName string + reload time.Duration + md5sum [md5.Size]byte + domains map[string]weights + randomGen + mutex sync.Mutex + } + // Per domain weights + weights []*weightItem + // Weight assigned to an address + weightItem struct { + address net.IP + value uint8 + } + // Random uint generator + randomGen interface { + randInit() + randUint(limit uint) uint + } +) + +// Random uint generator +type randomUint struct { + rn *rand.Rand +} + +func (r *randomUint) randInit() { + r.rn = rand.New(rand.NewSource(time.Now().UnixNano())) +} + +func (r *randomUint) randUint(limit uint) uint { + return uint(r.rn.Intn(int(limit))) +} + +func weightedShuffle(res *dns.Msg, w *weightedRR) *dns.Msg { + switch res.Question[0].Qtype { + case dns.TypeA, dns.TypeAAAA, dns.TypeSRV: + res.Answer = w.weightedRoundRobin(res.Answer) + res.Extra = w.weightedRoundRobin(res.Extra) + } + return res +} + +func weightedOnStartUp(w *weightedRR, stopReloadChan chan bool) error { + err := w.updateWeights() + if errors.Is(err, errOpen) && w.reload != 0 { + log.Warningf("Failed to open weight file:%v. 
Will try again in %v", + err, w.reload) + } else if err != nil { + return plugin.Error("loadbalance", err) + } + // start periodic weight file reload go routine + w.periodicWeightUpdate(stopReloadChan) + return nil +} + +func createWeightedFuncs(weightFileName string, + reload time.Duration) *lbFuncs { + lb := &lbFuncs{ + weighted: &weightedRR{ + fileName: weightFileName, + reload: reload, + randomGen: &randomUint{}, + }, + } + lb.weighted.randomGen.randInit() + + lb.shuffleFunc = func(res *dns.Msg) *dns.Msg { + return weightedShuffle(res, lb.weighted) + } + + stopReloadChan := make(chan bool) + + lb.onStartUpFunc = func() error { + return weightedOnStartUp(lb.weighted, stopReloadChan) + } + + lb.onShutdownFunc = func() error { + // stop periodic weigh reload go routine + close(stopReloadChan) + return nil + } + return lb +} + +// Apply weighted round robin policy to the answer +func (w *weightedRR) weightedRoundRobin(in []dns.RR) []dns.RR { + cname := []dns.RR{} + address := []dns.RR{} + mx := []dns.RR{} + rest := []dns.RR{} + for _, r := range in { + switch r.Header().Rrtype { + case dns.TypeCNAME: + cname = append(cname, r) + case dns.TypeA, dns.TypeAAAA: + address = append(address, r) + case dns.TypeMX: + mx = append(mx, r) + default: + rest = append(rest, r) + } + } + + if len(address) == 0 { + // no change + return in + } + + w.setTopRecord(address) + + out := append(cname, rest...) + out = append(out, address...) + out = append(out, mx...) + return out +} + +// Move the next expected address to the first position in the result list +func (w *weightedRR) setTopRecord(address []dns.RR) { + itop := w.topAddressIndex(address) + + if itop < 0 { + // internal error + return + } + + if itop != 0 { + // swap the selected top entry with the actual one + address[0], address[itop] = address[itop], address[0] + } +} + +// Compute the top (first) address index +func (w *weightedRR) topAddressIndex(address []dns.RR) int { + w.mutex.Lock() + defer w.mutex.Unlock() + + // Determine the weight value for each address in the answer + var wsum uint + type waddress struct { + index int + weight uint8 + } + weightedAddr := make([]waddress, len(address)) + for i, ar := range address { + wa := &weightedAddr[i] + wa.index = i + wa.weight = 1 // default weight + var ip net.IP + switch ar.Header().Rrtype { + case dns.TypeA: + ip = ar.(*dns.A).A + case dns.TypeAAAA: + ip = ar.(*dns.AAAA).AAAA + } + ws := w.domains[ar.Header().Name] + for _, w := range ws { + if w.address.Equal(ip) { + wa.weight = w.value + break + } + } + wsum += uint(wa.weight) + } + + // Select the first (top) IP + sort.Slice(weightedAddr, func(i, j int) bool { + return weightedAddr[i].weight > weightedAddr[j].weight + }) + v := w.randUint(wsum) + var psum uint + for _, wa := range weightedAddr { + psum += uint(wa.weight) + if v < psum { + return wa.index + } + } + + // we should never reach this + log.Errorf("Internal error: cannot find top address (randv:%v wsum:%v)", v, wsum) + return -1 +} + +// Start go routine to update weights from the weight file periodically +func (w *weightedRR) periodicWeightUpdate(stopReload <-chan bool) { + if w.reload == 0 { + return + } + + go func() { + ticker := time.NewTicker(w.reload) + for { + select { + case <-stopReload: + return + case <-ticker.C: + err := w.updateWeights() + if err != nil { + log.Error(err) + } + } + } + }() +} + +// Update weights from weight file +func (w *weightedRR) updateWeights() error { + reader, err := os.Open(filepath.Clean(w.fileName)) + if err != nil { + return errOpen + } 
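+	// The contents are hashed below so that an unchanged file is not re-parsed.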
+ defer reader.Close() + + // check if the contents has changed + var buf bytes.Buffer + tee := io.TeeReader(reader, &buf) + bytes, err := io.ReadAll(tee) + if err != nil { + return err + } + md5sum := md5.Sum(bytes) + if md5sum == w.md5sum { + // file contents has not changed + return nil + } + w.md5sum = md5sum + scanner := bufio.NewScanner(&buf) + + // Parse the weight file contents + domains, err := w.parseWeights(scanner) + if err != nil { + return err + } + + // access to weights must be protected + w.mutex.Lock() + w.domains = domains + w.mutex.Unlock() + + log.Infof("Successfully reloaded weight file %s", w.fileName) + return nil +} + +// Parse the weight file contents +func (w *weightedRR) parseWeights(scanner *bufio.Scanner) (map[string]weights, error) { + var dname string + var ws weights + domains := make(map[string]weights) + + for scanner.Scan() { + nextLine := strings.TrimSpace(scanner.Text()) + if len(nextLine) == 0 || nextLine[0:1] == "#" { + // Empty and comment lines are ignored + continue + } + fields := strings.Fields(nextLine) + switch len(fields) { + case 1: + // (domain) name sanity check + if net.ParseIP(fields[0]) != nil { + return nil, fmt.Errorf("Wrong domain name:\"%s\" in weight file %s. (Maybe a missing weight value?)", + fields[0], w.fileName) + } + dname = fields[0] + + // add the root domain if it is missing + if dname[len(dname)-1] != '.' { + dname += "." + } + var ok bool + ws, ok = domains[dname] + if !ok { + ws = make(weights, 0) + domains[dname] = ws + } + case 2: + // IP address and weight value + ip := net.ParseIP(fields[0]) + if ip == nil { + return nil, fmt.Errorf("Wrong IP address:\"%s\" in weight file %s", fields[0], w.fileName) + } + weight, err := strconv.ParseUint(fields[1], 10, 8) + if err != nil || weight == 0 { + return nil, fmt.Errorf("Wrong weight value:\"%s\" in weight file %s", fields[1], w.fileName) + } + witem := &weightItem{address: ip, value: uint8(weight)} + if dname == "" { + return nil, fmt.Errorf("Missing domain name in weight file %s", w.fileName) + } + ws = append(ws, witem) + domains[dname] = ws + default: + return nil, fmt.Errorf("Could not parse weight line:\"%s\" in weight file %s", nextLine, w.fileName) + } + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("Weight file %s parsing error:%s", w.fileName, err) + } + + return domains, nil +} diff --git a/plugin/loadbalance/weighted_test.go b/plugin/loadbalance/weighted_test.go new file mode 100644 index 0000000..fa596f0 --- /dev/null +++ b/plugin/loadbalance/weighted_test.go @@ -0,0 +1,430 @@ +package loadbalance + +import ( + "context" + "errors" + "net" + "strings" + "testing" + "time" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + testutil "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +const oneDomainWRR = ` +w1,example.org +192.168.1.15 10 +192.168.1.14 20 +` + +var testOneDomainWRR = map[string]weights{ + "w1,example.org.": { + &weightItem{net.ParseIP("192.168.1.15"), uint8(10)}, + &weightItem{net.ParseIP("192.168.1.14"), uint8(20)}, + }, +} + +const twoDomainsWRR = ` +# domain 1 +w1.example.org +192.168.1.15 10 +192.168.1.14 20 + +# domain 2 +w2.example.org + # domain 3 + w3.example.org + 192.168.2.16 11 + 192.168.2.15 12 + 192.168.2.14 13 +` + +var testTwoDomainsWRR = map[string]weights{ + "w1.example.org.": { + &weightItem{net.ParseIP("192.168.1.15"), uint8(10)}, + &weightItem{net.ParseIP("192.168.1.14"), uint8(20)}, + }, + "w2.example.org.": {}, + "w3.example.org.": { + &weightItem{net.ParseIP("192.168.2.16"), 
uint8(11)}, + &weightItem{net.ParseIP("192.168.2.15"), uint8(12)}, + &weightItem{net.ParseIP("192.168.2.14"), uint8(13)}, + }, +} + +const missingWeightWRR = ` +w1,example.org +192.168.1.14 +192.168.1.15 20 +` + +const missingDomainWRR = ` +# missing domain +192.168.1.14 10 +w2,example.org +192.168.2.14 11 +192.168.2.15 12 +` + +const wrongIpWRR = ` +w1,example.org +192.168.1.300 10 +` + +const wrongWeightWRR = ` +w1,example.org +192.168.1.14 300 +` + +const zeroWeightWRR = ` +w1,example.org +192.168.1.14 0 +` + +func TestWeightFileUpdate(t *testing.T) { + tests := []struct { + weightFilContent string + shouldErr bool + expectedDomains map[string]weights + expectedErrContent string // substring from the expected error. Empty for positive cases. + }{ + // positive + {"", false, nil, ""}, + {oneDomainWRR, false, testOneDomainWRR, ""}, + {twoDomainsWRR, false, testTwoDomainsWRR, ""}, + // negative + {missingWeightWRR, true, nil, "Wrong domain name"}, + {missingDomainWRR, true, nil, "Missing domain name"}, + {wrongIpWRR, true, nil, "Wrong IP address"}, + {wrongWeightWRR, true, nil, "Wrong weight value"}, + {zeroWeightWRR, true, nil, "Wrong weight value"}, + } + + for i, test := range tests { + testFile, rm, err := testutil.TempFile(".", test.weightFilContent) + if err != nil { + t.Fatal(err) + } + defer rm() + weighted := &weightedRR{fileName: testFile} + err = weighted.updateWeights() + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected error but found %s", i, err) + } + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: Expected no error but found error: %v", i, err) + } + + if !strings.Contains(err.Error(), test.expectedErrContent) { + t.Errorf("Test %d: Expected error to contain: %v, found error: %v", + i, test.expectedErrContent, err) + } + } + if test.expectedDomains != nil { + if len(test.expectedDomains) != len(weighted.domains) { + t.Errorf("Test %d: Expected len(domains): %d but got %d", + i, len(test.expectedDomains), len(weighted.domains)) + } else { + _ = checkDomainsWRR(t, i, test.expectedDomains, weighted.domains) + } + } + } +} + +func checkDomainsWRR(t *testing.T, testIndex int, expectedDomains, domains map[string]weights) error { + var ret error + retError := errors.New("Check domains failed") + for dname, expectedWeights := range expectedDomains { + ws, ok := domains[dname] + if !ok { + t.Errorf("Test %d: Expected domain %s but not found it", testIndex, dname) + ret = retError + } else { + if len(expectedWeights) != len(ws) { + t.Errorf("Test %d: Expected len(weights): %d for domain %s but got %d", + testIndex, len(expectedWeights), dname, len(ws)) + ret = retError + } else { + for i, w := range expectedWeights { + if !w.address.Equal(ws[i].address) || w.value != ws[i].value { + t.Errorf("Test %d: Weight list differs at index %d for domain %s. 
"+ + "Expected: %v got: %v", testIndex, i, dname, expectedWeights[i], ws[i]) + ret = retError + } + } + } + } + } + + return ret +} + +func TestPeriodicWeightUpdate(t *testing.T) { + testFile1, rm, err := testutil.TempFile(".", oneDomainWRR) + if err != nil { + t.Fatal(err) + } + defer rm() + testFile2, rm, err := testutil.TempFile(".", twoDomainsWRR) + if err != nil { + t.Fatal(err) + } + defer rm() + + // configure weightedRR with "oneDomainWRR" weight file content + weighted := &weightedRR{fileName: testFile1} + + err = weighted.updateWeights() + if err != nil { + t.Fatal(err) + } else { + err = checkDomainsWRR(t, 0, testOneDomainWRR, weighted.domains) + if err != nil { + t.Fatalf("Initial check domains failed") + } + } + + // change weight file + weighted.fileName = testFile2 + // start periodic update + weighted.reload = 10 * time.Millisecond + stopChan := make(chan bool) + weighted.periodicWeightUpdate(stopChan) + time.Sleep(20 * time.Millisecond) + // stop periodic update + close(stopChan) + // check updated config + weighted.mutex.Lock() + err = checkDomainsWRR(t, 0, testTwoDomainsWRR, weighted.domains) + weighted.mutex.Unlock() + if err != nil { + t.Fatalf("Final check domains failed") + } +} + +// Fake random number generator for testing +type fakeRandomGen struct { + expectedLimit uint + testIndex int + queryIndex int + randv uint + t *testing.T +} + +func (r *fakeRandomGen) randInit() { +} + +func (r *fakeRandomGen) randUint(limit uint) uint { + if limit != r.expectedLimit { + r.t.Errorf("Test %d query %d: Expected weights sum %d but got %d", + r.testIndex, r.queryIndex, r.expectedLimit, limit) + } + return r.randv +} + +func TestLoadBalanceWRR(t *testing.T) { + type testQuery struct { + randv uint // fake random value for selecting the top IP + topIP string // top (first) address record in the answer + } + + // domain maps to test + oneDomain := map[string]weights{ + "endpoint.region2.skydns.test.": { + &weightItem{net.ParseIP("10.240.0.2"), uint8(3)}, + &weightItem{net.ParseIP("10.240.0.1"), uint8(2)}, + }, + } + twoDomains := map[string]weights{ + "endpoint.region2.skydns.test.": { + &weightItem{net.ParseIP("10.240.0.2"), uint8(5)}, + &weightItem{net.ParseIP("10.240.0.1"), uint8(2)}, + }, + "endpoint.region1.skydns.test.": { + &weightItem{net.ParseIP("::2"), uint8(4)}, + &weightItem{net.ParseIP("::1"), uint8(3)}, + }, + } + + // the first X records must be cnames after this test + tests := []struct { + answer []dns.RR + extra []dns.RR + cnameAnswer int + cnameExtra int + addressAnswer int + addressExtra int + mxAnswer int + mxExtra int + domains map[string]weights + sumWeights uint // sum of weights in the answer + queries []testQuery + }{ + { + answer: []dns.RR{ + testutil.CNAME("cname1.region2.skydns.test. 300 IN CNAME cname2.region2.skydns.test."), + testutil.CNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."), + testutil.CNAME("cname5.region2.skydns.test. 300 IN CNAME cname6.region2.skydns.test."), + testutil.CNAME("cname6.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."), + testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), + testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.2"), + testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"), + testutil.AAAA("endpoint.region1.skydns.test. 300 IN AAAA ::1"), + testutil.AAAA("endpoint.region1.skydns.test. 300 IN AAAA ::2"), + testutil.MX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."), + testutil.MX("mx.region2.skydns.test. 
300 IN MX 2 mx2.region2.skydns.test."), + testutil.MX("mx.region2.skydns.test. 300 IN MX 3 mx3.region2.skydns.test."), + }, + extra: []dns.RR{ + testutil.CNAME("cname6.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."), + testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), + testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.2"), + testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"), + testutil.AAAA("endpoint.region1.skydns.test. 300 IN AAAA ::1"), + testutil.AAAA("endpoint.region1.skydns.test. 300 IN AAAA ::2"), + testutil.MX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."), + }, + cnameAnswer: 4, + cnameExtra: 1, + addressAnswer: 5, + addressExtra: 5, + mxAnswer: 3, + mxExtra: 1, + domains: twoDomains, + sumWeights: 15, + queries: []testQuery{ + {0, "10.240.0.2"}, // domain 1 weight 5 + {4, "10.240.0.2"}, // domain 1 weight 5 + {5, "::2"}, // domain 2 weight 4 + {8, "::2"}, // domain 2 weight 4 + {9, "::1"}, // domain 2 weight 3 + {11, "::1"}, // domain 2 weight 3 + {12, "10.240.0.1"}, // domain 1 weight 2 + {13, "10.240.0.1"}, // domain 1 weight 2 + {14, "10.240.0.3"}, // domain 1 no weight -> default weight + }, + }, + { + answer: []dns.RR{ + testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), + testutil.MX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."), + testutil.CNAME("cname.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."), + testutil.A("endpoint.region2.skydns.test. 300 IN A 10.240.0.2"), + testutil.A("endpoint.region1.skydns.test. 300 IN A 10.240.0.3"), + }, + cnameAnswer: 1, + addressAnswer: 3, + mxAnswer: 1, + domains: oneDomain, + sumWeights: 6, + queries: []testQuery{ + {0, "10.240.0.2"}, // weight 3 + {2, "10.240.0.2"}, // weight 3 + {3, "10.240.0.1"}, // weight 2 + {4, "10.240.0.1"}, // weight 2 + {5, "10.240.0.3"}, // no domain -> default weight + }, + }, + { + answer: []dns.RR{ + testutil.MX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."), + testutil.CNAME("cname.region2.skydns.test. 
300 IN CNAME endpoint.region2.skydns.test."), + }, + cnameAnswer: 1, + mxAnswer: 1, + domains: oneDomain, + queries: []testQuery{ + {0, ""}, // no address records -> answer unaltered + }, + }, + } + + testRand := &fakeRandomGen{t: t} + weighted := &weightedRR{randomGen: testRand} + shuffle := func(res *dns.Msg) *dns.Msg { + return weightedShuffle(res, weighted) + } + rm := LoadBalance{Next: handler(), shuffle: shuffle} + + rec := dnstest.NewRecorder(&testutil.ResponseWriter{}) + + for i, test := range tests { + // set domain map for weighted round robin + weighted.domains = test.domains + testRand.testIndex = i + testRand.expectedLimit = test.sumWeights + + for j, query := range test.queries { + req := new(dns.Msg) + req.SetQuestion("endpoint.region2.skydns.test", dns.TypeSRV) + req.Answer = test.answer + req.Extra = test.extra + + // Set fake random number + testRand.randv = query.randv + testRand.queryIndex = j + + _, err := rm.ServeDNS(context.TODO(), rec, req) + if err != nil { + t.Errorf("Test %d: Expected no error, but got %s", i, err) + continue + } + + checkTopIP(t, i, j, rec.Msg.Answer, query.topIP) + checkTopIP(t, i, j, rec.Msg.Extra, query.topIP) + + cname, address, mx, sorted := countRecords(rec.Msg.Answer) + if query.topIP != "" && !sorted { + t.Errorf("Test %d query %d: Expected CNAMEs, then AAAAs, then MX in Answer, but got mixed", i, j) + } + if cname != test.cnameAnswer { + t.Errorf("Test %d query %d: Expected %d CNAMEs in Answer, but got %d", i, j, test.cnameAnswer, cname) + } + if address != test.addressAnswer { + t.Errorf("Test %d query %d: Expected %d A/AAAAs in Answer, but got %d", i, j, test.addressAnswer, address) + } + if mx != test.mxAnswer { + t.Errorf("Test %d query %d: Expected %d MXs in Answer, but got %d", i, j, test.mxAnswer, mx) + } + + cname, address, mx, sorted = countRecords(rec.Msg.Extra) + if query.topIP != "" && !sorted { + t.Errorf("Test %d query %d: Expected CNAMEs, then AAAAs, then MX in Answer, but got mixed", i, j) + } + + if cname != test.cnameExtra { + t.Errorf("Test %d query %d: Expected %d CNAMEs in Extra, but got %d", i, j, test.cnameAnswer, cname) + } + if address != test.addressExtra { + t.Errorf("Test %d query %d: Expected %d A/AAAAs in Extra, but got %d", i, j, test.addressAnswer, address) + } + if mx != test.mxExtra { + t.Errorf("Test %d query %d: Expected %d MXs in Extra, but got %d", i, j, test.mxAnswer, mx) + } + } + } +} + +func checkTopIP(t *testing.T, i, j int, result []dns.RR, expectedTopIP string) { + expected := net.ParseIP(expectedTopIP) + for _, r := range result { + switch r.Header().Rrtype { + case dns.TypeA: + ar := r.(*dns.A) + if !ar.A.Equal(expected) { + t.Errorf("Test %d query %d: expected top IP %s but got %s", i, j, expectedTopIP, ar.A) + } + return + case dns.TypeAAAA: + ar := r.(*dns.AAAA) + if !ar.AAAA.Equal(expected) { + t.Errorf("Test %d query %d: expected top IP %s but got %s", i, j, expectedTopIP, ar.AAAA) + } + return + } + } +} diff --git a/plugin/local/README.md b/plugin/local/README.md new file mode 100644 index 0000000..08fff01 --- /dev/null +++ b/plugin/local/README.md @@ -0,0 +1,52 @@ +# local + +## Name + +*local* - respond to local names. + +## Description + +*local* will respond with a basic reply to a "local request". Local request are defined to be +names in the following zones: localhost, 0.in-addr.arpa, 127.in-addr.arpa and 255.in-addr.arpa *and* +any query asking for `localhost.<domain>`. When seeing the latter a metric counter is increased and +if *debug* is enabled a debug log is emitted. 
+ +With *local* enabled any query falling under these zones will get a reply. The prevents the query +from "escaping" to the internet and putting strain on external infrastructure. + +The zones are mostly empty, only `localhost.` address records (A and AAAA) are defined and a +`1.0.0.127.in-addr.arpa.` reverse (PTR) record. + +## Syntax + +~~~ txt +local +~~~ + +## Metrics + +If monitoring is enabled (via the *prometheus* plugin) then the following metric is exported: + +* `coredns_local_localhost_requests_total{}` - a counter of the number of `localhost.<domain>` + requests CoreDNS has seen. Note this does *not* count `localhost.` queries. + +Note that this metric *does not* have a `server` label, because it's more interesting to find the +client(s) performing these queries than to see which server handled it. You'll need to inspect the +debug log to get the client IP address. + +## Examples + +~~~ corefile +. { + local +} +~~~ + +## Bugs + +Only the `in-addr.arpa.` reverse zone is implemented, `ip6.arpa.` queries are not intercepted. + +## See Also + +BIND9's configuration in Debian comes with these zones preconfigured. See the *debug* plugin for +enabling debug logging. diff --git a/plugin/local/local.go b/plugin/local/local.go new file mode 100644 index 0000000..570f113 --- /dev/null +++ b/plugin/local/local.go @@ -0,0 +1,127 @@ +package local + +import ( + "context" + "net" + "strings" + + "github.com/coredns/coredns/plugin" + clog "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +var log = clog.NewWithPlugin("local") + +// Local is a plugin that returns standard replies for local queries. +type Local struct { + Next plugin.Handler +} + +var zones = []string{"localhost.", "0.in-addr.arpa.", "127.in-addr.arpa.", "255.in-addr.arpa."} + +func soaFromOrigin(origin string) []dns.RR { + hdr := dns.RR_Header{Name: origin, Ttl: ttl, Class: dns.ClassINET, Rrtype: dns.TypeSOA} + return []dns.RR{&dns.SOA{Hdr: hdr, Ns: "localhost.", Mbox: "root.localhost.", Serial: 1, Refresh: 0, Retry: 0, Expire: 0, Minttl: ttl}} +} + +func nsFromOrigin(origin string) []dns.RR { + hdr := dns.RR_Header{Name: origin, Ttl: ttl, Class: dns.ClassINET, Rrtype: dns.TypeNS} + return []dns.RR{&dns.NS{Hdr: hdr, Ns: "localhost."}} +} + +// ServeDNS implements the plugin.Handler interface. +func (l Local) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + state := request.Request{W: w, Req: r} + qname := state.QName() + + lc := len("localhost.") + if len(state.Name()) > lc && strings.HasPrefix(state.Name(), "localhost.") { + // we have multiple labels, but the first one is localhost, intercept this and return 127.0.0.1 or ::1 + log.Debugf("Intercepting localhost query for %q %s, from %s", state.Name(), state.Type(), state.IP()) + LocalhostCount.Inc() + reply := doLocalhost(state) + w.WriteMsg(reply) + return 0, nil + } + + zone := plugin.Zones(zones).Matches(qname) + if zone == "" { + return plugin.NextOrFailure(l.Name(), l.Next, ctx, w, r) + } + + m := new(dns.Msg) + m.SetReply(r) + zone = qname[len(qname)-len(zone):] + + switch q := state.Name(); q { + case "localhost.", "0.in-addr.arpa.", "127.in-addr.arpa.", "255.in-addr.arpa.": + switch state.QType() { + case dns.TypeA: + if q != "localhost." 
{ + // nodata + m.Ns = soaFromOrigin(qname) + break + } + + hdr := dns.RR_Header{Name: qname, Ttl: ttl, Class: dns.ClassINET, Rrtype: dns.TypeA} + m.Answer = []dns.RR{&dns.A{Hdr: hdr, A: net.ParseIP("127.0.0.1").To4()}} + case dns.TypeAAAA: + if q != "localhost." { + // nodata + m.Ns = soaFromOrigin(qname) + break + } + + hdr := dns.RR_Header{Name: qname, Ttl: ttl, Class: dns.ClassINET, Rrtype: dns.TypeAAAA} + m.Answer = []dns.RR{&dns.AAAA{Hdr: hdr, AAAA: net.ParseIP("::1")}} + case dns.TypeSOA: + m.Answer = soaFromOrigin(qname) + case dns.TypeNS: + m.Answer = nsFromOrigin(qname) + default: + // nodata + m.Ns = soaFromOrigin(qname) + } + case "1.0.0.127.in-addr.arpa.": + switch state.QType() { + case dns.TypePTR: + hdr := dns.RR_Header{Name: qname, Ttl: ttl, Class: dns.ClassINET, Rrtype: dns.TypePTR} + m.Answer = []dns.RR{&dns.PTR{Hdr: hdr, Ptr: "localhost."}} + default: + // nodata + m.Ns = soaFromOrigin(zone) + } + } + + if len(m.Answer) == 0 && len(m.Ns) == 0 { + m.Ns = soaFromOrigin(zone) + m.Rcode = dns.RcodeNameError + } + + w.WriteMsg(m) + return 0, nil +} + +// Name implements the plugin.Handler interface. +func (l Local) Name() string { return "local" } + +func doLocalhost(state request.Request) *dns.Msg { + m := new(dns.Msg) + m.SetReply(state.Req) + switch state.QType() { + case dns.TypeA: + hdr := dns.RR_Header{Name: state.QName(), Ttl: ttl, Class: dns.ClassINET, Rrtype: dns.TypeA} + m.Answer = []dns.RR{&dns.A{Hdr: hdr, A: net.ParseIP("127.0.0.1").To4()}} + case dns.TypeAAAA: + hdr := dns.RR_Header{Name: state.QName(), Ttl: ttl, Class: dns.ClassINET, Rrtype: dns.TypeAAAA} + m.Answer = []dns.RR{&dns.AAAA{Hdr: hdr, AAAA: net.ParseIP("::1")}} + default: + // nodata + m.Ns = soaFromOrigin(state.QName()) + } + return m +} + +const ttl = 604800 diff --git a/plugin/local/local_test.go b/plugin/local/local_test.go new file mode 100644 index 0000000..8e1561a --- /dev/null +++ b/plugin/local/local_test.go @@ -0,0 +1,77 @@ +package local + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +var testcases = []struct { + question string + qtype uint16 + rcode int + answer dns.RR + ns dns.RR +}{ + {"localhost.", dns.TypeA, dns.RcodeSuccess, test.A("localhost. IN A 127.0.0.1"), nil}, + {"localHOst.", dns.TypeA, dns.RcodeSuccess, test.A("localHOst. IN A 127.0.0.1"), nil}, + {"localhost.", dns.TypeAAAA, dns.RcodeSuccess, test.AAAA("localhost. IN AAAA ::1"), nil}, + {"localhost.", dns.TypeNS, dns.RcodeSuccess, test.NS("localhost. IN NS localhost."), nil}, + {"localhost.", dns.TypeSOA, dns.RcodeSuccess, test.SOA("localhost. IN SOA root.localhost. localhost. 1 0 0 0 0"), nil}, + {"127.in-addr.arpa.", dns.TypeA, dns.RcodeSuccess, nil, test.SOA("127.in-addr.arpa. IN SOA root.localhost. localhost. 1 0 0 0 0")}, + {"localhost.", dns.TypeMX, dns.RcodeSuccess, nil, test.SOA("localhost. IN SOA root.localhost. localhost. 1 0 0 0 0")}, + {"a.localhost.", dns.TypeA, dns.RcodeNameError, nil, test.SOA("localhost. IN SOA root.localhost. localhost. 1 0 0 0 0")}, + {"1.0.0.127.in-addr.arpa.", dns.TypePTR, dns.RcodeSuccess, test.PTR("1.0.0.127.in-addr.arpa. IN PTR localhost."), nil}, + {"1.0.0.127.in-addr.arpa.", dns.TypeMX, dns.RcodeSuccess, nil, test.SOA("127.in-addr.arpa. IN SOA root.localhost. localhost. 1 0 0 0 0")}, + {"2.0.0.127.in-addr.arpa.", dns.TypePTR, dns.RcodeNameError, nil, test.SOA("127.in-addr.arpa. IN SOA root.localhost. localhost. 
1 0 0 0 0")}, + {"localhost.example.net.", dns.TypeA, dns.RcodeSuccess, test.A("localhost.example.net. IN A 127.0.0.1"), nil}, + {"localhost.example.net.", dns.TypeAAAA, dns.RcodeSuccess, test.AAAA("localhost.example.net IN AAAA ::1"), nil}, + {"localhost.example.net.", dns.TypeSOA, dns.RcodeSuccess, nil, test.SOA("localhost.example.net. IN SOA root.localhost.example.net. localhost.example.net. 1 0 0 0 0")}, +} + +func TestLocal(t *testing.T) { + req := new(dns.Msg) + l := &Local{} + + for i, tc := range testcases { + req.SetQuestion(tc.question, tc.qtype) + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err := l.ServeDNS(context.TODO(), rec, req) + + if err != nil { + t.Errorf("Test %d, expected no error, but got %q", i, err) + continue + } + if rec.Msg.Rcode != tc.rcode { + t.Errorf("Test %d, expected rcode %d, got %d", i, tc.rcode, rec.Msg.Rcode) + } + if tc.answer == nil && len(rec.Msg.Answer) > 0 { + t.Errorf("Test %d, expected no answer RR, got %s", i, rec.Msg.Answer[0]) + continue + } + if tc.ns == nil && len(rec.Msg.Ns) > 0 { + t.Errorf("Test %d, expected no authority RR, got %s", i, rec.Msg.Ns[0]) + continue + } + if tc.answer != nil { + if x := tc.answer.Header().Rrtype; x != rec.Msg.Answer[0].Header().Rrtype { + t.Errorf("Test %d, expected RR type %d in answer, got %d", i, x, rec.Msg.Answer[0].Header().Rrtype) + } + if x := tc.answer.Header().Name; x != rec.Msg.Answer[0].Header().Name { + t.Errorf("Test %d, expected RR name %q in answer, got %q", i, x, rec.Msg.Answer[0].Header().Name) + } + } + if tc.ns != nil { + if x := tc.ns.Header().Rrtype; x != rec.Msg.Ns[0].Header().Rrtype { + t.Errorf("Test %d, expected RR type %d in authority, got %d", i, x, rec.Msg.Ns[0].Header().Rrtype) + } + if x := tc.ns.Header().Name; x != rec.Msg.Ns[0].Header().Name { + t.Errorf("Test %d, expected RR name %q in authority, got %q", i, x, rec.Msg.Ns[0].Header().Name) + } + } + } +} diff --git a/plugin/local/metrics.go b/plugin/local/metrics.go new file mode 100644 index 0000000..361f9ab --- /dev/null +++ b/plugin/local/metrics.go @@ -0,0 +1,18 @@ +package local + +import ( + "github.com/coredns/coredns/plugin" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +var ( + // LocalhostCount report the number of times we've seen a localhost.<domain> query. + LocalhostCount = promauto.NewCounter(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "local", + Name: "localhost_requests_total", + Help: "Counter of localhost.<domain> requests.", + }) +) diff --git a/plugin/local/setup.go b/plugin/local/setup.go new file mode 100644 index 0000000..9bd0dd6 --- /dev/null +++ b/plugin/local/setup.go @@ -0,0 +1,20 @@ +package local + +import ( + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" +) + +func init() { plugin.Register("local", setup) } + +func setup(c *caddy.Controller) error { + l := Local{} + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + l.Next = next + return l + }) + + return nil +} diff --git a/plugin/log/README.md b/plugin/log/README.md new file mode 100644 index 0000000..52dd9d7 --- /dev/null +++ b/plugin/log/README.md @@ -0,0 +1,155 @@ +# log + +## Name + +*log* - enables query logging to standard output. + +## Description + +By just using *log* you dump all queries (and parts for the reply) on standard output. Options exist +to tweak the output a little. 
Note that for busy servers logging will incur a performance hit. + +Enabling or disabling the *log* plugin only affects the query logging, any other logging from +CoreDNS will show up regardless. + +## Syntax + +~~~ txt +log +~~~ + +With no arguments, a query log entry is written to *stdout* in the common log format for all requests. +Or if you want/need slightly more control: + +~~~ txt +log [NAMES...] [FORMAT] +~~~ + +* `NAMES` is the name list to match in order to be logged +* `FORMAT` is the log format to use (default is Common Log Format), `{common}` is used as a shortcut + for the Common Log Format. You can also use `{combined}` for a format that adds the query opcode + `{>opcode}` to the Common Log Format. + +You can further specify the classes of responses that get logged: + +~~~ txt +log [NAMES...] [FORMAT] { + class CLASSES... +} +~~~ + +* `CLASSES` is a space-separated list of classes of responses that should be logged + +The classes of responses have the following meaning: + +* `success`: successful response +* `denial`: either NXDOMAIN or nodata responses (Name exists, type does not). A nodata response + sets the return code to NOERROR. +* `error`: SERVFAIL, NOTIMP, REFUSED, etc. Anything that indicates the remote server is not willing to + resolve the request. +* `all`: the default - nothing is specified. Using of this class means that all messages will be + logged whatever we mix together with "all". + +If no class is specified, it defaults to `all`. + +## Log Format + +You can specify a custom log format with any placeholder values. Log supports both request and +response placeholders. + +The following place holders are supported: + +* `{type}`: qtype of the request +* `{name}`: qname of the request +* `{class}`: qclass of the request +* `{proto}`: protocol used (tcp or udp) +* `{remote}`: client's IP address, for IPv6 addresses these are enclosed in brackets: `[::1]` +* `{local}`: server's IP address, for IPv6 addresses these are enclosed in brackets: `[::1]` +* `{size}`: request size in bytes +* `{port}`: client's port +* `{duration}`: response duration +* `{rcode}`: response RCODE +* `{rsize}`: raw (uncompressed), response size (a client may receive a smaller response) +* `{>rflags}`: response flags, each set flag will be displayed, e.g. "aa, tc". This includes the qr + bit as well +* `{>bufsize}`: the EDNS0 buffer size advertised in the query +* `{>do}`: is the EDNS0 DO (DNSSEC OK) bit set in the query +* `{>id}`: query ID +* `{>opcode}`: query OPCODE +* `{common}`: the default Common Log Format. +* `{combined}`: the Common Log Format with the query opcode. +* `{/LABEL}`: any metadata label is accepted as a place holder if it is enclosed between `{/` and + `}`, the place holder will be replaced by the corresponding metadata value or the default value + `-` if label is not defined. See the *metadata* plugin for more information. + +The default Common Log Format is: + +~~~ txt +`{remote}:{port} - {>id} "{type} {class} {name} {proto} {size} {>do} {>bufsize}" {rcode} {>rflags} {rsize} {duration}` +~~~ + +Each of these logs will be outputted with `log.Infof`, so a typical example looks like this: + +~~~ txt +[INFO] [::1]:50759 - 29008 "A IN example.org. udp 41 false 4096" NOERROR qr,rd,ra,ad 68 0.037990251s +~~~ + +## Examples + +Log all requests to stdout + +~~~ corefile +. { + log + whoami +} +~~~ + +Custom log format, for all zones (`.`) + +~~~ corefile +. { + log . 
"{proto} Request: {name} {type} {>id}" +} +~~~ + +Only log denials (NXDOMAIN and nodata) for example.org (and below) + +~~~ corefile +. { + log example.org { + class denial + } +} +~~~ + +Log all queries which were not resolved successfully in the Combined Log Format. + +~~~ corefile +. { + log . {combined} { + class denial error + } +} +~~~ + +Log all queries on which we did not get errors + +~~~ corefile +. { + log . { + class denial success + } +} +~~~ + +Also the multiple statements can be OR-ed, for example, we can rewrite the above case as following: + +~~~ corefile +. { + log . { + class denial + class success + } +} +~~~ diff --git a/plugin/log/log.go b/plugin/log/log.go new file mode 100644 index 0000000..8a3575f --- /dev/null +++ b/plugin/log/log.go @@ -0,0 +1,74 @@ +// Package log implements basic but useful request (access) logging plugin. +package log + +import ( + "context" + "time" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/dnstest" + clog "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/plugin/pkg/replacer" + "github.com/coredns/coredns/plugin/pkg/response" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// Logger is a basic request logging plugin. +type Logger struct { + Next plugin.Handler + Rules []Rule + + repl replacer.Replacer +} + +// ServeDNS implements the plugin.Handler interface. +func (l Logger) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + state := request.Request{W: w, Req: r} + name := state.Name() + for _, rule := range l.Rules { + if !plugin.Name(rule.NameScope).Matches(name) { + continue + } + + rrw := dnstest.NewRecorder(w) + rc, err := plugin.NextOrFailure(l.Name(), l.Next, ctx, rrw, r) + + // If we don't set up a class in config, the default "all" will be added + // and we shouldn't have an empty rule.Class. + _, ok := rule.Class[response.All] + var ok1 bool + if !ok { + tpe, _ := response.Typify(rrw.Msg, time.Now().UTC()) + class := response.Classify(tpe) + _, ok1 = rule.Class[class] + } + if ok || ok1 { + logstr := l.repl.Replace(ctx, state, rrw, rule.Format) + clog.Info(logstr) + } + + return rc, err + } + return plugin.NextOrFailure(l.Name(), l.Next, ctx, w, r) +} + +// Name implements the Handler interface. +func (l Logger) Name() string { return "log" } + +// Rule configures the logging plugin. +type Rule struct { + NameScope string + Class map[response.Class]struct{} + Format string +} + +const ( + // CommonLogFormat is the common log format. + CommonLogFormat = `{remote}:{port} ` + replacer.EmptyValue + ` {>id} "{type} {class} {name} {proto} {size} {>do} {>bufsize}" {rcode} {>rflags} {rsize} {duration}` + // CombinedLogFormat is the combined log format. + CombinedLogFormat = CommonLogFormat + ` "{>opcode}"` + // DefaultLogFormat is the default log format. 
+ DefaultLogFormat = CommonLogFormat +) diff --git a/plugin/log/log_test.go b/plugin/log/log_test.go new file mode 100644 index 0000000..e2f3acf --- /dev/null +++ b/plugin/log/log_test.go @@ -0,0 +1,280 @@ +package log + +import ( + "bytes" + "context" + "io" + "log" + "strings" + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + clog "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/plugin/pkg/replacer" + "github.com/coredns/coredns/plugin/pkg/response" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func init() { clog.Discard() } + +func TestLoggedStatus(t *testing.T) { + rule := Rule{ + NameScope: ".", + Format: DefaultLogFormat, + Class: map[response.Class]struct{}{response.All: {}}, + } + + var f bytes.Buffer + log.SetOutput(&f) + + logger := Logger{ + Rules: []Rule{rule}, + Next: test.ErrorHandler(), + repl: replacer.New(), + } + + ctx := context.TODO() + r := new(dns.Msg) + r.SetQuestion("example.org.", dns.TypeA) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + + rcode, _ := logger.ServeDNS(ctx, rec, r) + if rcode != 2 { + t.Errorf("Expected rcode to be 2 - was: %d", rcode) + } + + logged := f.String() + if !strings.Contains(logged, "A IN example.org. udp 29 false 512") { + t.Errorf("Expected it to be logged. Logged string: %s", logged) + } +} + +func TestLoggedClassDenial(t *testing.T) { + rule := Rule{ + NameScope: ".", + Format: DefaultLogFormat, + Class: map[response.Class]struct{}{response.Denial: {}}, + } + + var f bytes.Buffer + log.SetOutput(&f) + + logger := Logger{ + Rules: []Rule{rule}, + Next: test.ErrorHandler(), + repl: replacer.New(), + } + + ctx := context.TODO() + r := new(dns.Msg) + r.SetQuestion("example.org.", dns.TypeA) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + + logger.ServeDNS(ctx, rec, r) + + logged := f.String() + if len(logged) != 0 { + t.Errorf("Expected it not to be logged, but got string: %s", logged) + } +} + +func TestLoggedClassError(t *testing.T) { + rule := Rule{ + NameScope: ".", + Format: DefaultLogFormat, + Class: map[response.Class]struct{}{response.Error: {}}, + } + + var f bytes.Buffer + log.SetOutput(&f) + + logger := Logger{ + Rules: []Rule{rule}, + Next: test.ErrorHandler(), + repl: replacer.New(), + } + + ctx := context.TODO() + r := new(dns.Msg) + r.SetQuestion("example.org.", dns.TypeA) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + + logger.ServeDNS(ctx, rec, r) + + logged := f.String() + if !strings.Contains(logged, "SERVFAIL") { + t.Errorf("Expected it to be logged. 
Logged string: %s", logged) + } +} + +func TestLogged(t *testing.T) { + tests := []struct { + Rules []Rule + Domain string + ShouldLog bool + ShouldString string + ShouldNOTString string // for test format + }{ + // case for NameScope + { + Rules: []Rule{ + { + NameScope: "example.org.", + Format: DefaultLogFormat, + Class: map[response.Class]struct{}{response.All: {}}, + }, + }, + Domain: "example.org.", + ShouldLog: true, + ShouldString: "A IN example.org.", + }, + { + Rules: []Rule{ + { + NameScope: "example.org.", + Format: DefaultLogFormat, + Class: map[response.Class]struct{}{response.All: {}}, + }, + }, + Domain: "example.net.", + ShouldLog: false, + ShouldString: "", + }, + { + Rules: []Rule{ + { + NameScope: "example.org.", + Format: DefaultLogFormat, + Class: map[response.Class]struct{}{response.All: {}}, + }, + { + NameScope: "example.net.", + Format: DefaultLogFormat, + Class: map[response.Class]struct{}{response.All: {}}, + }, + }, + Domain: "example.net.", + ShouldLog: true, + ShouldString: "A IN example.net.", + }, + + // case for format + { + Rules: []Rule{ + { + NameScope: ".", + Format: "{type} {class}", + Class: map[response.Class]struct{}{response.All: {}}, + }, + }, + Domain: "example.org.", + ShouldLog: true, + ShouldString: "A IN", + ShouldNOTString: "example.org", + }, + { + Rules: []Rule{ + { + NameScope: ".", + Format: "{remote}:{port}", + Class: map[response.Class]struct{}{response.All: {}}, + }, + }, + Domain: "example.org.", + ShouldLog: true, + ShouldString: "10.240.0.1:40212", + ShouldNOTString: "A IN example.org", + }, + { + Rules: []Rule{ + { + NameScope: ".", + Format: CombinedLogFormat, + Class: map[response.Class]struct{}{response.All: {}}, + }, + }, + Domain: "example.org.", + ShouldLog: true, + ShouldString: "\"0\"", + }, + { + Rules: []Rule{ + { + NameScope: ".", + Format: CombinedLogFormat, + Class: map[response.Class]struct{}{response.All: {}}, + }, + }, + Domain: "foo.%s.example.org.", + ShouldLog: true, + ShouldString: "foo.%s.example.org.", + ShouldNOTString: "%!s(MISSING)", + }, + } + + for _, tc := range tests { + var f bytes.Buffer + log.SetOutput(&f) + + logger := Logger{ + Rules: tc.Rules, + Next: test.ErrorHandler(), + repl: replacer.New(), + } + + ctx := context.TODO() + r := new(dns.Msg) + r.SetQuestion(tc.Domain, dns.TypeA) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + + _, err := logger.ServeDNS(ctx, rec, r) + if err != nil { + t.Fatal(err) + } + + logged := f.String() + + if !tc.ShouldLog && len(logged) != 0 { + t.Errorf("Expected it not to be logged, but got string: %s", logged) + } + if tc.ShouldLog && !strings.Contains(logged, tc.ShouldString) { + t.Errorf("Expected it to contains: %s. Logged string: %s", tc.ShouldString, logged) + } + if tc.ShouldLog && tc.ShouldNOTString != "" && strings.Contains(logged, tc.ShouldNOTString) { + t.Errorf("Expected it to NOT contains: %s. 
Logged string: %s", tc.ShouldNOTString, logged) + } + } +} + +func BenchmarkLogged(b *testing.B) { + log.SetOutput(io.Discard) + + rule := Rule{ + NameScope: ".", + Format: DefaultLogFormat, + Class: map[response.Class]struct{}{response.All: {}}, + } + + logger := Logger{ + Rules: []Rule{rule}, + Next: test.ErrorHandler(), + repl: replacer.New(), + } + + ctx := context.TODO() + r := new(dns.Msg) + r.SetQuestion("example.org.", dns.TypeA) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + + b.StartTimer() + for i := 0; i < b.N; i++ { + logger.ServeDNS(ctx, rec, r) + } +} diff --git a/plugin/log/setup.go b/plugin/log/setup.go new file mode 100644 index 0000000..e1d9913 --- /dev/null +++ b/plugin/log/setup.go @@ -0,0 +1,102 @@ +package log + +import ( + "strings" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/replacer" + "github.com/coredns/coredns/plugin/pkg/response" + + "github.com/miekg/dns" +) + +func init() { plugin.Register("log", setup) } + +func setup(c *caddy.Controller) error { + rules, err := logParse(c) + if err != nil { + return plugin.Error("log", err) + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + return Logger{Next: next, Rules: rules, repl: replacer.New()} + }) + + return nil +} + +func logParse(c *caddy.Controller) ([]Rule, error) { + var rules []Rule + + for c.Next() { + args := c.RemainingArgs() + length := len(rules) + + switch len(args) { + case 0: + // Nothing specified; use defaults + rules = append(rules, Rule{ + NameScope: ".", + Format: DefaultLogFormat, + Class: make(map[response.Class]struct{}), + }) + case 1: + rules = append(rules, Rule{ + NameScope: dns.Fqdn(args[0]), + Format: DefaultLogFormat, + Class: make(map[response.Class]struct{}), + }) + default: + // Name scopes, and maybe a format specified + format := DefaultLogFormat + + if strings.Contains(args[len(args)-1], "{") { + format = args[len(args)-1] + format = strings.Replace(format, "{common}", CommonLogFormat, -1) + format = strings.Replace(format, "{combined}", CombinedLogFormat, -1) + args = args[:len(args)-1] + } + + for _, str := range args { + rules = append(rules, Rule{ + NameScope: dns.Fqdn(str), + Format: format, + Class: make(map[response.Class]struct{}), + }) + } + } + + // Class refinements in an extra block. + classes := make(map[response.Class]struct{}) + for c.NextBlock() { + switch c.Val() { + // class followed by combinations of all, denial, error and success. 
+ case "class": + classesArgs := c.RemainingArgs() + if len(classesArgs) == 0 { + return nil, c.ArgErr() + } + for _, c := range classesArgs { + cls, err := response.ClassFromString(c) + if err != nil { + return nil, err + } + classes[cls] = struct{}{} + } + default: + return nil, c.ArgErr() + } + } + if len(classes) == 0 { + classes[response.All] = struct{}{} + } + + for i := len(rules) - 1; i >= length; i-- { + rules[i].Class = classes + } + } + + return rules, nil +} diff --git a/plugin/log/setup_test.go b/plugin/log/setup_test.go new file mode 100644 index 0000000..2586ade --- /dev/null +++ b/plugin/log/setup_test.go @@ -0,0 +1,184 @@ +package log + +import ( + "reflect" + "testing" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/plugin/pkg/response" +) + +func TestLogParse(t *testing.T) { + tests := []struct { + inputLogRules string + shouldErr bool + expectedLogRules []Rule + }{ + {`log`, false, []Rule{{ + NameScope: ".", + Format: DefaultLogFormat, + Class: map[response.Class]struct{}{response.All: {}}, + }}}, + {`log example.org`, false, []Rule{{ + NameScope: "example.org.", + Format: DefaultLogFormat, + Class: map[response.Class]struct{}{response.All: {}}, + }}}, + {`log example.org. {common}`, false, []Rule{{ + NameScope: "example.org.", + Format: CommonLogFormat, + Class: map[response.Class]struct{}{response.All: {}}, + }}}, + {`log example.org {combined}`, false, []Rule{{ + NameScope: "example.org.", + Format: CombinedLogFormat, + Class: map[response.Class]struct{}{response.All: {}}, + }}}, + {`log example.org. + log example.net {combined}`, false, []Rule{{ + NameScope: "example.org.", + Format: DefaultLogFormat, + Class: map[response.Class]struct{}{response.All: {}}, + }, { + NameScope: "example.net.", + Format: CombinedLogFormat, + Class: map[response.Class]struct{}{response.All: {}}, + }}}, + {`log example.org {host} + log example.org {when}`, false, []Rule{{ + NameScope: "example.org.", + Format: "{host}", + Class: map[response.Class]struct{}{response.All: {}}, + }, { + NameScope: "example.org.", + Format: "{when}", + Class: map[response.Class]struct{}{response.All: {}}, + }}}, + {`log example.org example.net`, false, []Rule{{ + NameScope: "example.org.", + Format: DefaultLogFormat, + Class: map[response.Class]struct{}{response.All: {}}, + }, { + NameScope: "example.net.", + Format: DefaultLogFormat, + Class: map[response.Class]struct{}{response.All: {}}, + }}}, + {`log example.org example.net {host}`, false, []Rule{{ + NameScope: "example.org.", + Format: "{host}", + Class: map[response.Class]struct{}{response.All: {}}, + }, { + NameScope: "example.net.", + Format: "{host}", + Class: map[response.Class]struct{}{response.All: {}}, + }}}, + {`log example.org example.net {when} { + class denial + }`, false, []Rule{{ + NameScope: "example.org.", + Format: "{when}", + Class: map[response.Class]struct{}{response.Denial: {}}, + }, { + NameScope: "example.net.", + Format: "{when}", + Class: map[response.Class]struct{}{response.Denial: {}}, + }}}, + + {`log example.org { + class all + }`, false, []Rule{{ + NameScope: "example.org.", + Format: CommonLogFormat, + Class: map[response.Class]struct{}{response.All: {}}, + }}}, + {`log example.org { + class denial + }`, false, []Rule{{ + NameScope: "example.org.", + Format: CommonLogFormat, + Class: map[response.Class]struct{}{response.Denial: {}}, + }}}, + {`log { + class denial + }`, false, []Rule{{ + NameScope: ".", + Format: CommonLogFormat, + Class: map[response.Class]struct{}{response.Denial: {}}, + }}}, + {`log { + 
class denial error + }`, false, []Rule{{ + NameScope: ".", + Format: CommonLogFormat, + Class: map[response.Class]struct{}{response.Denial: {}, response.Error: {}}, + }}}, + {`log { + class denial + class error + }`, false, []Rule{{ + NameScope: ".", + Format: CommonLogFormat, + Class: map[response.Class]struct{}{response.Denial: {}, response.Error: {}}, + }}}, + {`log { + class abracadabra + }`, true, []Rule{}}, + {`log { + class + }`, true, []Rule{}}, + {`log { + unknown + }`, true, []Rule{}}, + {`log example.org "{combined} {/forward/upstream}"`, false, []Rule{{ + NameScope: "example.org.", + Format: CombinedLogFormat + " {/forward/upstream}", + Class: map[response.Class]struct{}{response.All: {}}, + }}}, + {`log example.org "{common} {/forward/upstream}"`, false, []Rule{{ + NameScope: "example.org.", + Format: CommonLogFormat + " {/forward/upstream}", + Class: map[response.Class]struct{}{response.All: {}}, + }}}, + {`log example.org "{when} {combined} {/forward/upstream}"`, false, []Rule{{ + NameScope: "example.org.", + Format: "{when} " + CombinedLogFormat + " {/forward/upstream}", + Class: map[response.Class]struct{}{response.All: {}}, + }}}, + {`log example.org "{when} {common} {/forward/upstream}"`, false, []Rule{{ + NameScope: "example.org.", + Format: "{when} " + CommonLogFormat + " {/forward/upstream}", + Class: map[response.Class]struct{}{response.All: {}}, + }}}, + } + for i, test := range tests { + c := caddy.NewTestController("dns", test.inputLogRules) + actualLogRules, err := logParse(c) + + if err == nil && test.shouldErr { + t.Errorf("Test %d with input '%s' didn't error, but it should have", i, test.inputLogRules) + } else if err != nil && !test.shouldErr { + t.Errorf("Test %d with input '%s' errored, but it shouldn't have; got '%v'", + i, test.inputLogRules, err) + } + if len(actualLogRules) != len(test.expectedLogRules) { + t.Fatalf("Test %d expected %d no of Log rules, but got %d", + i, len(test.expectedLogRules), len(actualLogRules)) + } + for j, actualLogRule := range actualLogRules { + if actualLogRule.NameScope != test.expectedLogRules[j].NameScope { + t.Errorf("Test %d expected %dth LogRule NameScope for '%s' to be %s , but got %s", + i, j, test.inputLogRules, test.expectedLogRules[j].NameScope, actualLogRule.NameScope) + } + + if actualLogRule.Format != test.expectedLogRules[j].Format { + t.Errorf("Test %d expected %dth LogRule Format for '%s' to be %s , but got %s", + i, j, test.inputLogRules, test.expectedLogRules[j].Format, actualLogRule.Format) + } + + if !reflect.DeepEqual(actualLogRule.Class, test.expectedLogRules[j].Class) { + t.Errorf("Test %d expected %dth LogRule Class to be %v , but got %v", + i, j, test.expectedLogRules[j].Class, actualLogRule.Class) + } + } + } +} diff --git a/plugin/log_test.go b/plugin/log_test.go new file mode 100644 index 0000000..0ee4b7c --- /dev/null +++ b/plugin/log_test.go @@ -0,0 +1,5 @@ +package plugin + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/loop/README.md b/plugin/loop/README.md new file mode 100644 index 0000000..826f5c5 --- /dev/null +++ b/plugin/loop/README.md @@ -0,0 +1,93 @@ +# loop + +## Name + +*loop* - detects simple forwarding loops and halts the server. + +## Description + +The *loop* plugin will send a random probe query to ourselves and will then keep track of how many times +we see it. If we see it more than twice, we assume CoreDNS has seen a forwarding loop and we halt the process. 
+ +The plugin will try to send the query for up to 30 seconds. This is done to give CoreDNS enough time +to start up. Once a query has been successfully sent, *loop* disables itself to prevent a query of +death. + +Note that *loop* will _only_ send "looping queries" for the first zone given in the Server Block. + +The query sent is `<random number>.<random number>.zone` with type set to HINFO. + +## Syntax + +~~~ txt +loop +~~~ + +## Examples + +Start a server on the default port and load the *loop* and *forward* plugins. The *forward* plugin +forwards to it self. + +~~~ txt +. { + loop + forward . 127.0.0.1 +} +~~~ + +After CoreDNS has started it stops the process while logging: + +~~~ txt +plugin/loop: Loop (127.0.0.1:55953 -> :1053) detected for zone ".", see https://coredns.io/plugins/loop#troubleshooting. Query: "HINFO 4547991504243258144.3688648895315093531." +~~~ + +## Limitations + +This plugin only attempts to find simple static forwarding loops at start up time. To detect a loop, +the following must be true: + +* the loop must be present at start up time. + +* the loop must occur for the `HINFO` query type. + +## Troubleshooting + +When CoreDNS logs contain the message `Loop ... detected ...`, this means that the `loop` detection +plugin has detected an infinite forwarding loop in one of the upstream DNS servers. This is a fatal +error because operating with an infinite loop will consume memory and CPU until eventual out of +memory death by the host. + +A forwarding loop is usually caused by: + +* Most commonly, CoreDNS forwarding requests directly to itself. e.g. via a loopback address such as `127.0.0.1`, `::1` or `127.0.0.53` +* Less commonly, CoreDNS forwarding to an upstream server that in turn, forwards requests back to CoreDNS. + +To troubleshoot this problem, look in your Corefile for any `forward`s to the zone +in which the loop was detected. Make sure that they are not forwarding to a local address or +to another DNS server that is forwarding requests back to CoreDNS. If `forward` is +using a file (e.g. `/etc/resolv.conf`), make sure that file does not contain local addresses. + +### Troubleshooting Loops In Kubernetes Clusters + +When a CoreDNS Pod deployed in Kubernetes detects a loop, the CoreDNS Pod will start to "CrashLoopBackOff". +This is because Kubernetes will try to restart the Pod every time CoreDNS detects the loop and exits. + +A common cause of forwarding loops in Kubernetes clusters is an interaction with a local DNS cache +on the host node (e.g. `systemd-resolved`). For example, in certain configurations `systemd-resolved` will +put the loopback address `127.0.0.53` as a nameserver into `/etc/resolv.conf`. Kubernetes (via `kubelet`) by default +will pass this `/etc/resolv.conf` file to all Pods using the `default` dnsPolicy rendering them +unable to make DNS lookups (this includes CoreDNS Pods). CoreDNS uses this `/etc/resolv.conf` +as a list of upstreams to forward requests to. Since it contains a loopback address, CoreDNS ends up forwarding +requests to itself. + +There are many ways to work around this issue, some are listed here: + +* Add the following to your `kubelet` config yaml: `resolvConf: <path-to-your-real-resolv-conf-file>` (or via command line flag `--resolv-conf` deprecated in 1.10). Your "real" + `resolv.conf` is the one that contains the actual IPs of your upstream servers, and no local/loopback address. + This flag tells `kubelet` to pass an alternate `resolv.conf` to Pods. 
For systems using `systemd-resolved`, +`/run/systemd/resolve/resolv.conf` is typically the location of the "real" `resolv.conf`, +although this can be different depending on your distribution. +* Disable the local DNS cache on host nodes, and restore `/etc/resolv.conf` to the original. +* A quick and dirty fix is to edit your Corefile, replacing `forward . /etc/resolv.conf` with +the IP address of your upstream DNS, for example `forward . 8.8.8.8`. But this only fixes the issue for CoreDNS, +kubelet will continue to forward the invalid `resolv.conf` to all `default` dnsPolicy Pods, leaving them unable to resolve DNS. diff --git a/plugin/loop/log_test.go b/plugin/loop/log_test.go new file mode 100644 index 0000000..882b5c8 --- /dev/null +++ b/plugin/loop/log_test.go @@ -0,0 +1,5 @@ +package loop + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/loop/loop.go b/plugin/loop/loop.go new file mode 100644 index 0000000..8d29798 --- /dev/null +++ b/plugin/loop/loop.go @@ -0,0 +1,109 @@ +package loop + +import ( + "context" + "sync" + + "github.com/coredns/coredns/plugin" + clog "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +var log = clog.NewWithPlugin("loop") + +// Loop is a plugin that implements loop detection by sending a "random" query. +type Loop struct { + Next plugin.Handler + + zone string + qname string + addr string + + sync.RWMutex + i int + off bool +} + +// New returns a new initialized Loop. +func New(zone string) *Loop { return &Loop{zone: zone, qname: qname(zone)} } + +// ServeDNS implements the plugin.Handler interface. +func (l *Loop) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + if r.Question[0].Qtype != dns.TypeHINFO { + return plugin.NextOrFailure(l.Name(), l.Next, ctx, w, r) + } + if l.disabled() { + return plugin.NextOrFailure(l.Name(), l.Next, ctx, w, r) + } + + state := request.Request{W: w, Req: r} + + zone := plugin.Zones([]string{l.zone}).Matches(state.Name()) + if zone == "" { + return plugin.NextOrFailure(l.Name(), l.Next, ctx, w, r) + } + + if state.Name() == l.qname { + l.inc() + } + + if l.seen() > 2 { + log.Fatalf(`Loop (%s -> %s) detected for zone %q, see https://coredns.io/plugins/loop#troubleshooting. Query: "HINFO %s"`, state.RemoteAddr(), l.address(), l.zone, l.qname) + } + + return plugin.NextOrFailure(l.Name(), l.Next, ctx, w, r) +} + +// Name implements the plugin.Handler interface. 
+func (l *Loop) Name() string { return "loop" } + +func (l *Loop) exchange(addr string) (*dns.Msg, error) { + m := new(dns.Msg) + m.SetQuestion(l.qname, dns.TypeHINFO) + + return dns.Exchange(m, addr) +} + +func (l *Loop) seen() int { + l.RLock() + defer l.RUnlock() + return l.i +} + +func (l *Loop) inc() { + l.Lock() + defer l.Unlock() + l.i++ +} + +func (l *Loop) reset() { + l.Lock() + defer l.Unlock() + l.i = 0 +} + +func (l *Loop) setDisabled() { + l.Lock() + defer l.Unlock() + l.off = true +} + +func (l *Loop) disabled() bool { + l.RLock() + defer l.RUnlock() + return l.off +} + +func (l *Loop) setAddress(addr string) { + l.Lock() + defer l.Unlock() + l.addr = addr +} + +func (l *Loop) address() string { + l.RLock() + defer l.RUnlock() + return l.addr +} diff --git a/plugin/loop/loop_test.go b/plugin/loop/loop_test.go new file mode 100644 index 0000000..e7a4b06 --- /dev/null +++ b/plugin/loop/loop_test.go @@ -0,0 +1,11 @@ +package loop + +import "testing" + +func TestLoop(t *testing.T) { + l := New(".") + l.inc() + if l.seen() != 1 { + t.Errorf("Failed to inc loop, expected %d, got %d", 1, l.seen()) + } +} diff --git a/plugin/loop/setup.go b/plugin/loop/setup.go new file mode 100644 index 0000000..4e076c6 --- /dev/null +++ b/plugin/loop/setup.go @@ -0,0 +1,87 @@ +package loop + +import ( + "net" + "strconv" + "time" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/dnsutil" + "github.com/coredns/coredns/plugin/pkg/rand" +) + +func init() { plugin.Register("loop", setup) } + +func setup(c *caddy.Controller) error { + l, err := parse(c) + if err != nil { + return plugin.Error("loop", err) + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + l.Next = next + return l + }) + + // Send query to ourselves and see if it end up with us again. + c.OnStartup(func() error { + // Another Go function, otherwise we block startup and can't send the packet. + go func() { + deadline := time.Now().Add(30 * time.Second) + conf := dnsserver.GetConfig(c) + lh := conf.ListenHosts[0] + addr := net.JoinHostPort(lh, conf.Port) + + for time.Now().Before(deadline) { + l.setAddress(addr) + if _, err := l.exchange(addr); err != nil { + l.reset() + time.Sleep(1 * time.Second) + continue + } + + go func() { + time.Sleep(2 * time.Second) + l.setDisabled() + }() + + break + } + l.setDisabled() + }() + return nil + }) + + return nil +} + +func parse(c *caddy.Controller) (*Loop, error) { + i := 0 + zones := []string{"."} + for c.Next() { + if i > 0 { + return nil, plugin.ErrOnce + } + i++ + if c.NextArg() { + return nil, c.ArgErr() + } + + if len(c.ServerBlockKeys) > 0 { + zones = plugin.Host(c.ServerBlockKeys[0]).NormalizeExact() + } + } + return New(zones[0]), nil +} + +// qname returns a random name. <rand.Int()>.<rand.Int().<zone>. 
+func qname(zone string) string { + l1 := strconv.Itoa(r.Int()) + l2 := strconv.Itoa(r.Int()) + + return dnsutil.Join(l1, l2, zone) +} + +var r = rand.New(time.Now().UnixNano()) diff --git a/plugin/loop/setup_test.go b/plugin/loop/setup_test.go new file mode 100644 index 0000000..6b80b9b --- /dev/null +++ b/plugin/loop/setup_test.go @@ -0,0 +1,19 @@ +package loop + +import ( + "testing" + + "github.com/coredns/caddy" +) + +func TestSetup(t *testing.T) { + c := caddy.NewTestController("dns", `loop`) + if err := setup(c); err != nil { + t.Fatalf("Expected no errors, but got: %v", err) + } + + c = caddy.NewTestController("dns", `loop argument`) + if err := setup(c); err == nil { + t.Fatal("Expected errors, but got none") + } +} diff --git a/plugin/metadata/README.md b/plugin/metadata/README.md new file mode 100644 index 0000000..6eb2c39 --- /dev/null +++ b/plugin/metadata/README.md @@ -0,0 +1,49 @@ +# metadata + +## Name + +*metadata* - enables a metadata collector. + +## Description + +By enabling *metadata* any plugin that implements [metadata.Provider +interface](https://godoc.org/github.com/coredns/coredns/plugin/metadata#Provider) will be called for +each DNS query, at the beginning of the process for that query, in order to add its own metadata to +context. + +The metadata collected will be available for all plugins, via the Context parameter provided in the +ServeDNS function. The package (code) documentation has examples on how to inspect and retrieve +metadata a plugin might be interested in. + +The metadata is added by setting a label with a value in the context. These labels should be named +`plugin/NAME`, where **NAME** is something descriptive. The only hard requirement the *metadata* +plugin enforces is that the labels contain a slash. See the documentation for +`metadata.SetValueFunc`. + +The value stored is a string. The empty string signals "no metadata". See the documentation for +`metadata.ValueFunc` on how to retrieve this. + +## Syntax + +~~~ +metadata [ZONES... ] +~~~ + +* **ZONES** zones metadata should be invoked for. + +## Plugins + +`metadata.Provider` interface needs to be implemented by each plugin willing to provide metadata +information for other plugins. It will be called by metadata and gather the information from all +plugins in context. + +Note: this method should work quickly, because it is called for every request. + +## Examples + +The *rewrite* plugin uses meta data to rewrite requests. + +## See Also + +The [Provider interface](https://godoc.org/github.com/coredns/coredns/plugin/metadata#Provider) and +the [package level](https://godoc.org/github.com/coredns/coredns/plugin/metadata) documentation. diff --git a/plugin/metadata/log_test.go b/plugin/metadata/log_test.go new file mode 100644 index 0000000..8d1e924 --- /dev/null +++ b/plugin/metadata/log_test.go @@ -0,0 +1,5 @@ +package metadata + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/metadata/metadata.go b/plugin/metadata/metadata.go new file mode 100644 index 0000000..58e5ce2 --- /dev/null +++ b/plugin/metadata/metadata.go @@ -0,0 +1,44 @@ +package metadata + +import ( + "context" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// Metadata implements collecting metadata information from all plugins that +// implement the Provider interface. +type Metadata struct { + Zones []string + Providers []Provider + Next plugin.Handler +} + +// Name implements the Handler interface. 
+func (m *Metadata) Name() string { return "metadata" } + +// ContextWithMetadata is exported for use by provider tests +func ContextWithMetadata(ctx context.Context) context.Context { + return context.WithValue(ctx, key{}, md{}) +} + +// ServeDNS implements the plugin.Handler interface. +func (m *Metadata) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + rcode, err := plugin.NextOrFailure(m.Name(), m.Next, ctx, w, r) + return rcode, err +} + +// Collect will retrieve metadata functions from each metadata provider and update the context +func (m *Metadata) Collect(ctx context.Context, state request.Request) context.Context { + ctx = ContextWithMetadata(ctx) + if plugin.Zones(m.Zones).Matches(state.Name()) != "" { + // Go through all Providers and collect metadata. + for _, p := range m.Providers { + ctx = p.Metadata(ctx, state) + } + } + return ctx +} diff --git a/plugin/metadata/metadata_test.go b/plugin/metadata/metadata_test.go new file mode 100644 index 0000000..6b8da6d --- /dev/null +++ b/plugin/metadata/metadata_test.go @@ -0,0 +1,93 @@ +package metadata + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +type testProvider map[string]Func + +func (tp testProvider) Metadata(ctx context.Context, state request.Request) context.Context { + for k, v := range tp { + SetValueFunc(ctx, k, v) + } + return ctx +} + +type testHandler struct{ ctx context.Context } + +func (m *testHandler) Name() string { return "test" } + +func (m *testHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + m.ctx = ctx + return 0, nil +} + +func TestMetadataServeDNS(t *testing.T) { + expectedMetadata := []testProvider{ + {"test/key1": func() string { return "testvalue1" }}, + {"test/key2": func() string { return "two" }, "test/key3": func() string { return "testvalue3" }}, + } + // Create fake Providers based on expectedMetadata + providers := []Provider{} + for _, e := range expectedMetadata { + providers = append(providers, e) + } + + next := &testHandler{} // fake handler which stores the resulting context + m := Metadata{ + Zones: []string{"."}, + Providers: providers, + Next: next, + } + + ctx := context.TODO() + w := &test.ResponseWriter{} + r := new(dns.Msg) + ctx = m.Collect(ctx, request.Request{W: w, Req: r}) + m.ServeDNS(ctx, w, r) + nctx := next.ctx + + for _, expected := range expectedMetadata { + for label, expVal := range expected { + if !IsLabel(label) { + t.Errorf("Expected label %s is not considered a valid label", label) + } + val := ValueFunc(nctx, label) + if val() != expVal() { + t.Errorf("Expected value %s for %s, but got %s", expVal(), label, val()) + } + } + } +} + +func TestLabelFormat(t *testing.T) { + labels := []struct { + label string + isValid bool + }{ + // ok + {"plugin/LABEL", true}, + {"p/LABEL", true}, + {"plugin/L", true}, + {"PLUGIN/LABEL/SUB-LABEL", true}, + // fails + {"LABEL", false}, + {"plugin.LABEL", false}, + {"/NO-PLUGIN-NOT-ACCEPTED", false}, + {"ONLY-PLUGIN-NOT-ACCEPTED/", false}, + {"/", false}, + {"//", false}, + } + + for _, test := range labels { + if x := IsLabel(test.label); x != test.isValid { + t.Errorf("Label %v expected %v, got: %v", test.label, test.isValid, x) + } + } +} diff --git a/plugin/metadata/provider.go b/plugin/metadata/provider.go new file mode 100644 index 0000000..2e88d58 --- /dev/null +++ b/plugin/metadata/provider.go @@ -0,0 +1,126 @@ +// Package metadata provides an API that 
allows plugins to add metadata to the context. +// Each metadata is stored under a label that has the form <plugin>/<name>. Each metadata +// is returned as a Func. When Func is called the metadata is returned. If Func is expensive to +// execute it is its responsibility to provide some form of caching. During the handling of a +// query it is expected the metadata stays constant. +// +// Basic example: +// +// Implement the Provider interface for a plugin p: +// +// func (p P) Metadata(ctx context.Context, state request.Request) context.Context { +// metadata.SetValueFunc(ctx, "test/something", func() string { return "myvalue" }) +// return ctx +// } +// +// Basic example with caching: +// +// func (p P) Metadata(ctx context.Context, state request.Request) context.Context { +// cached := "" +// f := func() string { +// if cached != "" { +// return cached +// } +// cached = expensiveFunc() +// return cached +// } +// metadata.SetValueFunc(ctx, "test/something", f) +// return ctx +// } +// +// If you need access to this metadata from another plugin: +// +// // ... +// valueFunc := metadata.ValueFunc(ctx, "test/something") +// value := valueFunc() +// // use 'value' +package metadata + +import ( + "context" + "strings" + + "github.com/coredns/coredns/request" +) + +// Provider interface needs to be implemented by each plugin willing to provide +// metadata information for other plugins. +type Provider interface { + // Metadata adds metadata to the context and returns a (potentially) new context. + // Note: this method should work quickly, because it is called for every request + // from the metadata plugin. + Metadata(ctx context.Context, state request.Request) context.Context +} + +// Func is the type of function in the metadata, when called they return the value of the label. +type Func func() string + +// IsLabel checks that the provided name is a valid label name, i.e. two or more words separated by a slash. +func IsLabel(label string) bool { + p := strings.Index(label, "/") + if p <= 0 || p >= len(label)-1 { + // cannot accept namespace empty nor label empty + return false + } + return true +} + +// Labels returns all metadata keys stored in the context. These label names should be named +// as: plugin/NAME, where NAME is something descriptive. +func Labels(ctx context.Context) []string { + if metadata := ctx.Value(key{}); metadata != nil { + if m, ok := metadata.(md); ok { + return keys(m) + } + } + return nil +} + +// ValueFuncs returns the map[string]Func from the context, or nil if it does not exist. +func ValueFuncs(ctx context.Context) map[string]Func { + if metadata := ctx.Value(key{}); metadata != nil { + if m, ok := metadata.(md); ok { + return m + } + } + return nil +} + +// ValueFunc returns the value function of label. If none can be found nil is returned. Calling the +// function returns the value of the label. +func ValueFunc(ctx context.Context, label string) Func { + if metadata := ctx.Value(key{}); metadata != nil { + if m, ok := metadata.(md); ok { + return m[label] + } + } + return nil +} + +// SetValueFunc set the metadata label to the value function. If no metadata can be found this is a noop and +// false is returned. Any existing value is overwritten. +func SetValueFunc(ctx context.Context, label string, f Func) bool { + if metadata := ctx.Value(key{}); metadata != nil { + if m, ok := metadata.(md); ok { + m[label] = f + return true + } + } + return false +} + +// md is metadata information storage. 
+type md map[string]Func + +// key defines the type of key that is used to save metadata into the context. +type key struct{} + +func keys(m map[string]Func) []string { + s := make([]string, len(m)) + i := 0 + for k := range m { + s[i] = k + i++ + } + return s +} diff --git a/plugin/metadata/setup.go b/plugin/metadata/setup.go new file mode 100644 index 0000000..90b1cf9 --- /dev/null +++ b/plugin/metadata/setup.go @@ -0,0 +1,44 @@ +package metadata + +import ( + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" +) + +func init() { plugin.Register("metadata", setup) } + +func setup(c *caddy.Controller) error { + m, err := metadataParse(c) + if err != nil { + return err + } + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + m.Next = next + return m + }) + + c.OnStartup(func() error { + plugins := dnsserver.GetConfig(c).Handlers() + for _, p := range plugins { + if met, ok := p.(Provider); ok { + m.Providers = append(m.Providers, met) + } + } + return nil + }) + + return nil +} + +func metadataParse(c *caddy.Controller) (*Metadata, error) { + m := &Metadata{} + c.Next() + + m.Zones = plugin.OriginsFromArgsOrServerBlock(c.RemainingArgs(), c.ServerBlockKeys) + + if c.NextBlock() || c.Next() { + return nil, plugin.Error("metadata", c.ArgErr()) + } + return m, nil +} diff --git a/plugin/metadata/setup_test.go b/plugin/metadata/setup_test.go new file mode 100644 index 0000000..ed552f7 --- /dev/null +++ b/plugin/metadata/setup_test.go @@ -0,0 +1,70 @@ +package metadata + +import ( + "reflect" + "testing" + + "github.com/coredns/caddy" +) + +func TestSetup(t *testing.T) { + tests := []struct { + input string + zones []string + shouldErr bool + }{ + {"metadata", []string{}, false}, + {"metadata example.com.", []string{"example.com."}, false}, + {"metadata example.com. net.", []string{"example.com.", "net."}, false}, + + {"metadata example.com. { some_param }", []string{}, true}, + {"metadata\nmetadata", []string{}, true}, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + err := setup(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: Setup call expected error but found none for input %s", i, test.input) + } + + if !test.shouldErr && err != nil { + t.Errorf("Test %d: Setup call expected no error but found one for input %s. Error was: %v", i, test.input, err) + } + } +} + +func TestSetupHealth(t *testing.T) { + tests := []struct { + input string + zones []string + shouldErr bool + }{ + {"metadata", []string{}, false}, + {"metadata example.com.", []string{"example.com."}, false}, + {"metadata example.com. net.", []string{"example.com.", "net."}, false}, + + {"metadata example.com. { some_param }", []string{}, true}, + {"metadata\nmetadata", []string{}, true}, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + m, err := metadataParse(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected error but found none for input %s", i, test.input) + } + + if !test.shouldErr && err != nil { + t.Errorf("Test %d: Expected no error but found one for input %s. Error was: %v", i, test.input, err) + } + + if !test.shouldErr && err == nil { + if !reflect.DeepEqual(test.zones, m.Zones) { + t.Errorf("Test %d: Expected zones %s. 
Zones were: %v", i, test.zones, m.Zones) + } + } + } +} diff --git a/plugin/metrics/README.md b/plugin/metrics/README.md new file mode 100644 index 0000000..144a5d1 --- /dev/null +++ b/plugin/metrics/README.md @@ -0,0 +1,91 @@ +# prometheus + +## Name + +*prometheus* - enables [Prometheus](https://prometheus.io/) metrics. + +## Description + +With *prometheus* you export metrics from CoreDNS and any plugin that has them. +The default location for the metrics is `localhost:9153`. The metrics path is fixed to `/metrics`. + +In addition to the default Go metrics exported by the [Prometheus Go client](https://prometheus.io/docs/guides/go-application/), +the following metrics are exported: + +* `coredns_build_info{version, revision, goversion}` - info about CoreDNS itself. +* `coredns_panics_total{}` - total number of panics. +* `coredns_dns_requests_total{server, zone, view, proto, family, type}` - total query count. +* `coredns_dns_request_duration_seconds{server, zone, view, type}` - duration to process each query. +* `coredns_dns_request_size_bytes{server, zone, view, proto}` - size of the request in bytes. +* `coredns_dns_do_requests_total{server, view, zone}` - queries that have the DO bit set +* `coredns_dns_response_size_bytes{server, zone, view, proto}` - response size in bytes. +* `coredns_dns_responses_total{server, zone, view, rcode, plugin}` - response per zone, rcode and plugin. +* `coredns_dns_https_responses_total{server, status}` - responses per server and http status code. +* `coredns_dns_quic_responses_total{server, status}` - responses per server and QUIC application code. +* `coredns_plugin_enabled{server, zone, view, name}` - indicates whether a plugin is enabled on per server, zone and view basis. + +Almost each counter has a label `zone` which is the zonename used for the request/response. + +Extra labels used are: + +* `server` is identifying the server responsible for the request. This is a string formatted + as the server's listening address: `<scheme>://[<bind>]:<port>`. I.e. for a "normal" DNS server + this is `dns://:53`. If you are using the *bind* plugin an IP address is included, e.g.: `dns://127.0.0.53:53`. +* `proto` which holds the transport of the response ("udp" or "tcp") +* The address family (`family`) of the transport (1 = IP (IP version 4), 2 = IP6 (IP version 6)). +* `type` which holds the query type. It holds most common types (A, AAAA, MX, SOA, CNAME, PTR, TXT, + NS, SRV, DS, DNSKEY, RRSIG, NSEC, NSEC3, HTTPS, IXFR, AXFR and ANY) and "other" which lumps together all + other types. +* `status` which holds the https status code. Possible values are: + * 200 - request is processed, + * 404 - request has been rejected on validation, + * 400 - request to dns message conversion failed, + * 500 - processing ended up with no response. +* the `plugin` label holds the name of the plugin that made the write to the client. If the server + did the write (on error for instance), the value is empty. + +If monitoring is enabled, queries that do not enter the plugin chain are exported under the fake +name "dropped" (without a closing dot - this is never a valid domain name). + +Other plugins may export additional stats when the _prometheus_ plugin is enabled. Those stats are documented in each +plugin's README. + +This plugin can only be used once per Server Block. + +## Syntax + +~~~ +prometheus [ADDRESS] +~~~ + +For each zone that you want to see metrics for. 
+ +It optionally takes a bind address to which the metrics are exported; the default +listens on `localhost:9153`. The metrics path is fixed to `/metrics`. + +## Examples + +Use an alternative listening address: + +~~~ corefile +. { + prometheus localhost:9253 +} +~~~ + +Or via an environment variable (this is supported throughout the Corefile): `export PORT=9253`, and +then: + +~~~ corefile +. { + prometheus localhost:{$PORT} +} +~~~ + +## Bugs + +When reloading, the Prometheus handler is stopped before the new server instance is started. +If that new server fails to start, then the initial server instance is still available and DNS queries still served, +but Prometheus handler stays down. +Prometheus will not reply HTTP request until a successful reload or a complete restart of CoreDNS. +Only the plugins that register as Handler are visible in `coredns_plugin_enabled{server, zone, name}`. As of today the plugins reload and bind will not be reported. diff --git a/plugin/metrics/context.go b/plugin/metrics/context.go new file mode 100644 index 0000000..ae2856d --- /dev/null +++ b/plugin/metrics/context.go @@ -0,0 +1,37 @@ +package metrics + +import ( + "context" + + "github.com/coredns/coredns/core/dnsserver" +) + +// WithServer returns the current server handling the request. It returns the +// server listening address: <scheme>://[<bind>]:<port> Normally this is +// something like "dns://:53", but if the bind plugin is used, i.e. "bind +// 127.0.0.53", it will be "dns://127.0.0.53:53", etc. If not address is found +// the empty string is returned. +// +// Basic usage with a metric: +// +// <metric>.WithLabelValues(metrics.WithServer(ctx), labels..).Add(1) +func WithServer(ctx context.Context) string { + srv := ctx.Value(dnsserver.Key{}) + if srv == nil { + return "" + } + return srv.(*dnsserver.Server).Addr +} + +// WithView returns the name of the view currently handling the request, if a view is defined. +// +// Basic usage with a metric: +// +// <metric>.WithLabelValues(metrics.WithView(ctx), labels..).Add(1) +func WithView(ctx context.Context) string { + v := ctx.Value(dnsserver.ViewKey{}) + if v == nil { + return "" + } + return v.(string) +} diff --git a/plugin/metrics/handler.go b/plugin/metrics/handler.go new file mode 100644 index 0000000..41da690 --- /dev/null +++ b/plugin/metrics/handler.go @@ -0,0 +1,57 @@ +package metrics + +import ( + "context" + "path/filepath" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/metrics/vars" + "github.com/coredns/coredns/plugin/pkg/rcode" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// ServeDNS implements the Handler interface. +func (m *Metrics) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + state := request.Request{W: w, Req: r} + + qname := state.QName() + zone := plugin.Zones(m.ZoneNames()).Matches(qname) + if zone == "" { + zone = "." + } + + // Record response to get status code and size of the reply. 
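+	// The recorder also captures the writing caller's file path in WriteMsg, which is how
+	// the plugin label on coredns_dns_responses_total is resolved further down.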
+ rw := NewRecorder(w) + status, err := plugin.NextOrFailure(m.Name(), m.Next, ctx, rw, r) + + rc := rw.Rcode + if !plugin.ClientWrite(status) { + // when no response was written, fallback to status returned from next plugin as this status + // is actually used as rcode of DNS response + // see https://github.com/coredns/coredns/blob/master/core/dnsserver/server.go#L318 + rc = status + } + plugin := m.authoritativePlugin(rw.Caller) + vars.Report(WithServer(ctx), state, zone, WithView(ctx), rcode.ToString(rc), plugin, rw.Len, rw.Start) + + return status, err +} + +// Name implements the Handler interface. +func (m *Metrics) Name() string { return "prometheus" } + +// authoritativePlugin returns which of made the write, if none is found the empty string is returned. +func (m *Metrics) authoritativePlugin(caller [3]string) string { + // a b and c contain the full path of the caller, the plugin name 2nd last elements + // .../coredns/plugin/whoami/whoami.go --> whoami + // this is likely FS specific, so use filepath. + for _, c := range caller { + plug := filepath.Base(filepath.Dir(c)) + if _, ok := m.plugins[plug]; ok { + return plug + } + } + return "" +} diff --git a/plugin/metrics/log_test.go b/plugin/metrics/log_test.go new file mode 100644 index 0000000..101098a --- /dev/null +++ b/plugin/metrics/log_test.go @@ -0,0 +1,5 @@ +package metrics + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/metrics/metrics.go b/plugin/metrics/metrics.go new file mode 100644 index 0000000..6a9e652 --- /dev/null +++ b/plugin/metrics/metrics.go @@ -0,0 +1,172 @@ +// Package metrics implement a handler and plugin that provides Prometheus metrics. +package metrics + +import ( + "context" + "net" + "net/http" + "sync" + "time" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/reuseport" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +// Metrics holds the prometheus configuration. The metrics' path is fixed to be /metrics . +type Metrics struct { + Next plugin.Handler + Addr string + Reg *prometheus.Registry + + ln net.Listener + lnSetup bool + + mux *http.ServeMux + srv *http.Server + + zoneNames []string + zoneMap map[string]struct{} + zoneMu sync.RWMutex + + plugins map[string]struct{} // all available plugins, used to determine which plugin made the client write +} + +// New returns a new instance of Metrics with the given address. +func New(addr string) *Metrics { + met := &Metrics{ + Addr: addr, + Reg: prometheus.DefaultRegisterer.(*prometheus.Registry), + zoneMap: make(map[string]struct{}), + plugins: pluginList(caddy.ListPlugins()), + } + + return met +} + +// MustRegister wraps m.Reg.MustRegister. +func (m *Metrics) MustRegister(c prometheus.Collector) { + err := m.Reg.Register(c) + if err != nil { + // ignore any duplicate error, but fatal on any other kind of error + if _, ok := err.(prometheus.AlreadyRegisteredError); !ok { + log.Fatalf("Cannot register metrics collector: %s", err) + } + } +} + +// AddZone adds zone z to m. +func (m *Metrics) AddZone(z string) { + m.zoneMu.Lock() + m.zoneMap[z] = struct{}{} + m.zoneNames = keys(m.zoneMap) + m.zoneMu.Unlock() +} + +// RemoveZone remove zone z from m. 
+func (m *Metrics) RemoveZone(z string) { + m.zoneMu.Lock() + delete(m.zoneMap, z) + m.zoneNames = keys(m.zoneMap) + m.zoneMu.Unlock() +} + +// ZoneNames returns the zones of m. +func (m *Metrics) ZoneNames() []string { + m.zoneMu.RLock() + s := m.zoneNames + m.zoneMu.RUnlock() + return s +} + +// OnStartup sets up the metrics on startup. +func (m *Metrics) OnStartup() error { + ln, err := reuseport.Listen("tcp", m.Addr) + if err != nil { + log.Errorf("Failed to start metrics handler: %s", err) + return err + } + + m.ln = ln + m.lnSetup = true + + m.mux = http.NewServeMux() + m.mux.Handle("/metrics", promhttp.HandlerFor(m.Reg, promhttp.HandlerOpts{})) + + // creating some helper variables to avoid data races on m.srv and m.ln + server := &http.Server{Handler: m.mux} + m.srv = server + + go func() { + server.Serve(ln) + }() + + ListenAddr = ln.Addr().String() // For tests. + return nil +} + +// OnRestart stops the listener on reload. +func (m *Metrics) OnRestart() error { + if !m.lnSetup { + return nil + } + u.Unset(m.Addr) + return m.stopServer() +} + +func (m *Metrics) stopServer() error { + if !m.lnSetup { + return nil + } + ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) + defer cancel() + if err := m.srv.Shutdown(ctx); err != nil { + log.Infof("Failed to stop prometheus http server: %s", err) + return err + } + m.lnSetup = false + m.ln.Close() + return nil +} + +// OnFinalShutdown tears down the metrics listener on shutdown and restart. +func (m *Metrics) OnFinalShutdown() error { return m.stopServer() } + +func keys(m map[string]struct{}) []string { + sx := []string{} + for k := range m { + sx = append(sx, k) + } + return sx +} + +// pluginList iterates over the returned plugin map from caddy and removes the "dns." prefix from them. +func pluginList(m map[string][]string) map[string]struct{} { + pm := map[string]struct{}{} + for _, p := range m["others"] { + // only add 'dns.' plugins + if len(p) > 3 { + pm[p[4:]] = struct{}{} + continue + } + } + return pm +} + +// ListenAddr is assigned the address of the prometheus listener. Its use is mainly in tests where +// we listen on "localhost:0" and need to retrieve the actual address. 
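+// A sketch of the test-only usage (http.Get stands in for the test helper; the names
+// mirror metrics_test.go):
+//
+//	met := New("localhost:0")
+//	_ = met.OnStartup()
+//	defer met.OnFinalShutdown()
+//	resp, _ := http.Get("http://" + ListenAddr + "/metrics")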
+var ListenAddr string + +// shutdownTimeout is the maximum amount of time the metrics plugin will wait +// before erroring when it tries to close the metrics server +const shutdownTimeout time.Duration = time.Second * 5 + +var buildInfo = promauto.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: plugin.Namespace, + Name: "build_info", + Help: "A metric with a constant '1' value labeled by version, revision, and goversion from which CoreDNS was built.", +}, []string{"version", "revision", "goversion"}) diff --git a/plugin/metrics/metrics_test.go b/plugin/metrics/metrics_test.go new file mode 100644 index 0000000..bd72bf1 --- /dev/null +++ b/plugin/metrics/metrics_test.go @@ -0,0 +1,82 @@ +package metrics + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestMetrics(t *testing.T) { + met := New("localhost:0") + if err := met.OnStartup(); err != nil { + t.Fatalf("Failed to start metrics handler: %s", err) + } + defer met.OnFinalShutdown() + + met.AddZone("example.org.") + + tests := []struct { + next plugin.Handler + qname string + qtype uint16 + metric string + expectedValue string + }{ + // This all works because 1 bucket (1 zone, 1 type) + { + next: test.NextHandler(dns.RcodeSuccess, nil), + qname: "example.org.", + metric: "coredns_dns_requests_total", + expectedValue: "1", + }, + { + next: test.NextHandler(dns.RcodeSuccess, nil), + qname: "example.org.", + metric: "coredns_dns_requests_total", + expectedValue: "2", + }, + { + next: test.NextHandler(dns.RcodeSuccess, nil), + qname: "example.org.", + metric: "coredns_dns_requests_total", + expectedValue: "3", + }, + { + next: test.NextHandler(dns.RcodeSuccess, nil), + qname: "example.org.", + metric: "coredns_dns_responses_total", + expectedValue: "4", + }, + } + + ctx := context.TODO() + + for i, tc := range tests { + req := new(dns.Msg) + if tc.qtype == 0 { + tc.qtype = dns.TypeA + } + req.SetQuestion(tc.qname, tc.qtype) + met.Next = tc.next + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err := met.ServeDNS(ctx, rec, req) + if err != nil { + t.Fatalf("Test %d: Expected no error, but got %s", i, err) + } + + result := test.Scrape("http://" + ListenAddr + "/metrics") + + if tc.expectedValue != "" { + got, _ := test.MetricValue(tc.metric, result) + if got != tc.expectedValue { + t.Errorf("Test %d: Expected value %s for metrics %s, but got %s", i, tc.expectedValue, tc.metric, got) + } + } + } +} diff --git a/plugin/metrics/recorder.go b/plugin/metrics/recorder.go new file mode 100644 index 0000000..d4d42ba --- /dev/null +++ b/plugin/metrics/recorder.go @@ -0,0 +1,28 @@ +package metrics + +import ( + "runtime" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + + "github.com/miekg/dns" +) + +// Recorder is a dnstest.Recorder specific to the metrics plugin. +type Recorder struct { + *dnstest.Recorder + // CallerN holds the string return value of the call to runtime.Caller(N+1) + Caller [3]string +} + +// NewRecorder makes and returns a new Recorder. +func NewRecorder(w dns.ResponseWriter) *Recorder { return &Recorder{Recorder: dnstest.NewRecorder(w)} } + +// WriteMsg records the status code and calls the +// underlying ResponseWriter's WriteMsg method. 
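+// The three captured frames are later matched against the plugin list in
+// Metrics.authoritativePlugin, where the parent directory of the calling file names the
+// plugin, e.g. .../coredns/plugin/whoami/whoami.go resolves to "whoami".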
+func (r *Recorder) WriteMsg(res *dns.Msg) error { + _, r.Caller[0], _, _ = runtime.Caller(1) + _, r.Caller[1], _, _ = runtime.Caller(2) + _, r.Caller[2], _, _ = runtime.Caller(3) + return r.Recorder.WriteMsg(res) +} diff --git a/plugin/metrics/recorder_test.go b/plugin/metrics/recorder_test.go new file mode 100644 index 0000000..fd8c5fc --- /dev/null +++ b/plugin/metrics/recorder_test.go @@ -0,0 +1,68 @@ +package metrics + +import ( + "testing" + + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +type inmemoryWriter struct { + test.ResponseWriter + written []byte +} + +func (r *inmemoryWriter) WriteMsg(m *dns.Msg) error { + r.written, _ = m.Pack() + return r.ResponseWriter.WriteMsg(m) +} + +func (r *inmemoryWriter) Write(buf []byte) (int, error) { + r.written = buf + return r.ResponseWriter.Write(buf) +} + +func TestRecorder_WriteMsg(t *testing.T) { + successResp := dns.Msg{} + successResp.Answer = []dns.RR{ + test.A("a.example.org. 1800 IN A 127.0.0.53"), + } + + nxdomainResp := dns.Msg{} + nxdomainResp.Rcode = dns.RcodeNameError + + tests := []struct { + name string + msg *dns.Msg + }{ + { + name: "should record successful response", + msg: &successResp, + }, + { + name: "should record nxdomain response", + msg: &nxdomainResp, + }, + } + for i, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tw := inmemoryWriter{ResponseWriter: test.ResponseWriter{}} + rec := NewRecorder(&tw) + + if err := rec.WriteMsg(tt.msg); err != nil { + t.Errorf("Test %d: WriteMsg() unexpected error %v", i, err) + } + + if rec.Msg != tt.msg { + t.Errorf("Test %d: Expected value %v for msg, but got %v", i, tt.msg, rec.Msg) + } + if rec.Len != tt.msg.Len() { + t.Errorf("Test %d: Expected value %d for len, but got %d", i, tt.msg.Len(), rec.Len) + } + if rec.Rcode != tt.msg.Rcode { + t.Errorf("Test %d: Expected value %d for rcode, but got %d", i, tt.msg.Rcode, rec.Rcode) + } + }) + } +} diff --git a/plugin/metrics/registry.go b/plugin/metrics/registry.go new file mode 100644 index 0000000..2d6a92e --- /dev/null +++ b/plugin/metrics/registry.go @@ -0,0 +1,28 @@ +package metrics + +import ( + "sync" + + "github.com/prometheus/client_golang/prometheus" +) + +type reg struct { + sync.RWMutex + r map[string]*prometheus.Registry +} + +func newReg() *reg { return ®{r: make(map[string]*prometheus.Registry)} } + +// update sets the registry if not already there and returns the input. Or it returns +// a previous set value. 
+func (r *reg) getOrSet(addr string, pr *prometheus.Registry) *prometheus.Registry { + r.Lock() + defer r.Unlock() + + if v, ok := r.r[addr]; ok { + return v + } + + r.r[addr] = pr + return pr +} diff --git a/plugin/metrics/setup.go b/plugin/metrics/setup.go new file mode 100644 index 0000000..bee7d1f --- /dev/null +++ b/plugin/metrics/setup.go @@ -0,0 +1,105 @@ +package metrics + +import ( + "net" + "runtime" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/coremain" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/metrics/vars" + clog "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/plugin/pkg/uniq" +) + +var ( + log = clog.NewWithPlugin("prometheus") + u = uniq.New() + registry = newReg() +) + +func init() { plugin.Register("prometheus", setup) } + +func setup(c *caddy.Controller) error { + m, err := parse(c) + if err != nil { + return plugin.Error("prometheus", err) + } + m.Reg = registry.getOrSet(m.Addr, m.Reg) + + c.OnStartup(func() error { m.Reg = registry.getOrSet(m.Addr, m.Reg); u.Set(m.Addr, m.OnStartup); return nil }) + c.OnRestartFailed(func() error { m.Reg = registry.getOrSet(m.Addr, m.Reg); u.Set(m.Addr, m.OnStartup); return nil }) + + c.OnStartup(func() error { return u.ForEach() }) + c.OnRestartFailed(func() error { return u.ForEach() }) + + c.OnStartup(func() error { + conf := dnsserver.GetConfig(c) + for _, h := range conf.ListenHosts { + addrstr := conf.Transport + "://" + net.JoinHostPort(h, conf.Port) + for _, p := range conf.Handlers() { + vars.PluginEnabled.WithLabelValues(addrstr, conf.Zone, conf.ViewName, p.Name()).Set(1) + } + } + return nil + }) + c.OnRestartFailed(func() error { + conf := dnsserver.GetConfig(c) + for _, h := range conf.ListenHosts { + addrstr := conf.Transport + "://" + net.JoinHostPort(h, conf.Port) + for _, p := range conf.Handlers() { + vars.PluginEnabled.WithLabelValues(addrstr, conf.Zone, conf.ViewName, p.Name()).Set(1) + } + } + return nil + }) + + c.OnRestart(m.OnRestart) + c.OnRestart(func() error { vars.PluginEnabled.Reset(); return nil }) + c.OnFinalShutdown(m.OnFinalShutdown) + + // Initialize metrics. + buildInfo.WithLabelValues(coremain.CoreVersion, coremain.GitCommit, runtime.Version()).Set(1) + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + m.Next = next + return m + }) + + return nil +} + +func parse(c *caddy.Controller) (*Metrics, error) { + met := New(defaultAddr) + + i := 0 + for c.Next() { + if i > 0 { + return nil, plugin.ErrOnce + } + i++ + + zones := plugin.OriginsFromArgsOrServerBlock(nil /* args */, c.ServerBlockKeys) + for _, z := range zones { + met.AddZone(z) + } + args := c.RemainingArgs() + + switch len(args) { + case 0: + case 1: + met.Addr = args[0] + _, _, e := net.SplitHostPort(met.Addr) + if e != nil { + return met, e + } + default: + return met, c.ArgErr() + } + } + return met, nil +} + +// defaultAddr is the address the where the metrics are exported by default. 
+const defaultAddr = "localhost:9153" diff --git a/plugin/metrics/setup_test.go b/plugin/metrics/setup_test.go new file mode 100644 index 0000000..3a584a6 --- /dev/null +++ b/plugin/metrics/setup_test.go @@ -0,0 +1,42 @@ +package metrics + +import ( + "testing" + + "github.com/coredns/caddy" +) + +func TestPrometheusParse(t *testing.T) { + tests := []struct { + input string + shouldErr bool + addr string + }{ + // oks + {`prometheus`, false, "localhost:9153"}, + {`prometheus localhost:53`, false, "localhost:53"}, + // fails + {`prometheus {}`, true, ""}, + {`prometheus /foo`, true, ""}, + {`prometheus a b c`, true, ""}, + } + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + m, err := parse(c) + if test.shouldErr && err == nil { + t.Errorf("Test %v: Expected error but found nil", i) + continue + } else if !test.shouldErr && err != nil { + t.Errorf("Test %v: Expected no error but found error: %v", i, err) + continue + } + + if test.shouldErr { + continue + } + + if test.addr != m.Addr { + t.Errorf("Test %v: Expected address %s but found: %s", i, test.addr, m.Addr) + } + } +} diff --git a/plugin/metrics/vars/monitor.go b/plugin/metrics/vars/monitor.go new file mode 100644 index 0000000..191324e --- /dev/null +++ b/plugin/metrics/vars/monitor.go @@ -0,0 +1,36 @@ +package vars + +import ( + "github.com/miekg/dns" +) + +var monitorType = map[uint16]struct{}{ + dns.TypeAAAA: {}, + dns.TypeA: {}, + dns.TypeCNAME: {}, + dns.TypeDNSKEY: {}, + dns.TypeDS: {}, + dns.TypeMX: {}, + dns.TypeNSEC3: {}, + dns.TypeNSEC: {}, + dns.TypeNS: {}, + dns.TypePTR: {}, + dns.TypeRRSIG: {}, + dns.TypeSOA: {}, + dns.TypeSRV: {}, + dns.TypeTXT: {}, + dns.TypeHTTPS: {}, + // Meta Qtypes + dns.TypeIXFR: {}, + dns.TypeAXFR: {}, + dns.TypeANY: {}, +} + +// qTypeString returns the RR type based on monitorType. It returns the text representation +// of those types. RR types not in that list will have "other" returned. +func qTypeString(qtype uint16) string { + if _, known := monitorType[qtype]; known { + return dns.Type(qtype).String() + } + return "other" +} diff --git a/plugin/metrics/vars/report.go b/plugin/metrics/vars/report.go new file mode 100644 index 0000000..92f6bc1 --- /dev/null +++ b/plugin/metrics/vars/report.go @@ -0,0 +1,33 @@ +package vars + +import ( + "time" + + "github.com/coredns/coredns/request" +) + +// Report reports the metrics data associated with request. This function is exported because it is also +// called from core/dnsserver to report requests hitting the server that should not be handled and are thus +// not sent down the plugin chain. +func Report(server string, req request.Request, zone, view, rcode, plugin string, size int, start time.Time) { + // Proto and Family. 
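+	// The family label is "1" for IPv4 and "2" for IPv6, matching the README's description.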
+ net := req.Proto() + fam := "1" + if req.Family() == 2 { + fam = "2" + } + + if req.Do() { + RequestDo.WithLabelValues(server, zone, view).Inc() + } + + qType := qTypeString(req.QType()) + RequestCount.WithLabelValues(server, zone, view, net, fam, qType).Inc() + + RequestDuration.WithLabelValues(server, zone, view).Observe(time.Since(start).Seconds()) + + ResponseSize.WithLabelValues(server, zone, view, net).Observe(float64(size)) + RequestSize.WithLabelValues(server, zone, view, net).Observe(float64(req.Len())) + + ResponseRcode.WithLabelValues(server, zone, view, rcode, plugin).Inc() +} diff --git a/plugin/metrics/vars/vars.go b/plugin/metrics/vars/vars.go new file mode 100644 index 0000000..6de75c0 --- /dev/null +++ b/plugin/metrics/vars/vars.go @@ -0,0 +1,89 @@ +package vars + +import ( + "github.com/coredns/coredns/plugin" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +// Request* and Response* are the prometheus counters and gauges we are using for exporting metrics. +var ( + RequestCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: subsystem, + Name: "requests_total", + Help: "Counter of DNS requests made per zone, protocol and family.", + }, []string{"server", "zone", "view", "proto", "family", "type"}) + + RequestDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: plugin.Namespace, + Subsystem: subsystem, + Name: "request_duration_seconds", + Buckets: plugin.TimeBuckets, + Help: "Histogram of the time (in seconds) each request took per zone.", + }, []string{"server", "zone", "view"}) + + RequestSize = promauto.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: plugin.Namespace, + Subsystem: subsystem, + Name: "request_size_bytes", + Help: "Size of the EDNS0 UDP buffer in bytes (64K for TCP) per zone and protocol.", + Buckets: []float64{0, 100, 200, 300, 400, 511, 1023, 2047, 4095, 8291, 16e3, 32e3, 48e3, 64e3}, + }, []string{"server", "zone", "view", "proto"}) + + RequestDo = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: subsystem, + Name: "do_requests_total", + Help: "Counter of DNS requests with DO bit set per zone.", + }, []string{"server", "zone", "view"}) + + ResponseSize = promauto.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: plugin.Namespace, + Subsystem: subsystem, + Name: "response_size_bytes", + Help: "Size of the returned response in bytes.", + Buckets: []float64{0, 100, 200, 300, 400, 511, 1023, 2047, 4095, 8291, 16e3, 32e3, 48e3, 64e3}, + }, []string{"server", "zone", "view", "proto"}) + + ResponseRcode = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: subsystem, + Name: "responses_total", + Help: "Counter of response status codes.", + }, []string{"server", "zone", "view", "rcode", "plugin"}) + + Panic = promauto.NewCounter(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Name: "panics_total", + Help: "A metrics that counts the number of panics.", + }) + + PluginEnabled = promauto.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: plugin.Namespace, + Name: "plugin_enabled", + Help: "A metric that indicates whether a plugin is enabled on per server and zone basis.", + }, []string{"server", "zone", "view", "name"}) + + HTTPSResponsesCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: subsystem, + Name: "https_responses_total", + Help: "Counter of DoH responses per server and 
http status code.", + }, []string{"server", "status"}) + + QUICResponsesCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: subsystem, + Name: "quic_responses_total", + Help: "Counter of DoQ responses per server and QUIC application code.", + }, []string{"server", "status"}) +) + +const ( + subsystem = "dns" + + // Dropped indicates we dropped the query before any handling. It has no closing dot, so it can not be a valid zone. + Dropped = "dropped" +) diff --git a/plugin/minimal/README.md b/plugin/minimal/README.md new file mode 100644 index 0000000..a225743 --- /dev/null +++ b/plugin/minimal/README.md @@ -0,0 +1,36 @@ +# minimal + +## Name + +*minimal* - minimizes size of the DNS response message whenever possible. + +## Description + +The *minimal* plugin tries to minimize the size of the response. Depending on the response type it +removes resource records from the AUTHORITY and ADDITIONAL sections. + +Specifically this plugin looks at successful responses (this excludes negative responses, i.e. +nodata or name error). If the successful response isn't a delegation only the RRs in the answer +section are written to the client. + +## Syntax + +~~~ txt +minimal +~~~ + +## Examples + +Enable minimal responses: + +~~~ corefile +example.org { + whoami + forward . 8.8.8.8 + minimal +} +~~~ + +## See Also + +[BIND 9 Configuration Reference](https://bind9.readthedocs.io/en/latest/reference.html#boolean-options) diff --git a/plugin/minimal/minimal.go b/plugin/minimal/minimal.go new file mode 100644 index 0000000..0bac6a3 --- /dev/null +++ b/plugin/minimal/minimal.go @@ -0,0 +1,55 @@ +package minimal + +import ( + "context" + "fmt" + "time" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/nonwriter" + "github.com/coredns/coredns/plugin/pkg/response" + + "github.com/miekg/dns" +) + +// minimalHandler implements the plugin.Handler interface. 
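+// For successful, non-delegation replies it writes back only the question and answer
+// sections; denials, errors and delegations are passed through unchanged (see ServeDNS below).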
+type minimalHandler struct { + Next plugin.Handler +} + +func (m *minimalHandler) Name() string { return "minimal" } + +func (m *minimalHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + nw := nonwriter.New(w) + + rcode, err := plugin.NextOrFailure(m.Name(), m.Next, ctx, nw, r) + if err != nil { + return rcode, err + } + + ty, _ := response.Typify(nw.Msg, time.Now().UTC()) + cl := response.Classify(ty) + + // if response is Denial or Error pass through also if the type is Delegation pass through + if cl == response.Denial || cl == response.Error || ty == response.Delegation { + w.WriteMsg(nw.Msg) + return 0, nil + } + if ty != response.NoError { + w.WriteMsg(nw.Msg) + return 0, plugin.Error("minimal", fmt.Errorf("unhandled response type %q for %q", ty, nw.Msg.Question[0].Name)) + } + + // copy over the original Msg params, deep copy not required as RRs are not modified + d := &dns.Msg{ + MsgHdr: nw.Msg.MsgHdr, + Compress: nw.Msg.Compress, + Question: nw.Msg.Question, + Answer: nw.Msg.Answer, + Ns: nil, + Extra: nil, + } + + w.WriteMsg(d) + return 0, nil +} diff --git a/plugin/minimal/minimal_test.go b/plugin/minimal/minimal_test.go new file mode 100644 index 0000000..406d787 --- /dev/null +++ b/plugin/minimal/minimal_test.go @@ -0,0 +1,153 @@ +package minimal + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +// testHandler implements plugin.Handler and will be used to create a stub handler for the test +type testHandler struct { + Response *test.Case + Next plugin.Handler +} + +func (t *testHandler) Name() string { return "test-handler" } + +func (t *testHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + d := new(dns.Msg) + d.SetReply(r) + if t.Response != nil { + d.Answer = t.Response.Answer + d.Ns = t.Response.Ns + d.Extra = t.Response.Extra + d.Rcode = t.Response.Rcode + } + w.WriteMsg(d) + return 0, nil +} + +func TestMinimizeResponse(t *testing.T) { + baseAnswer := []dns.RR{ + test.A("example.com. 293 IN A 142.250.76.46"), + } + baseNs := []dns.RR{ + test.NS("example.com. 157127 IN NS ns2.example.com."), + test.NS("example.com. 157127 IN NS ns1.example.com."), + test.NS("example.com. 157127 IN NS ns3.example.com."), + test.NS("example.com. 157127 IN NS ns4.example.com."), + } + + baseExtra := []dns.RR{ + test.A("ns2.example.com. 316273 IN A 216.239.34.10"), + test.AAAA("ns2.example.com. 157127 IN AAAA 2001:4860:4802:34::a"), + test.A("ns3.example.com. 316274 IN A 216.239.36.10"), + test.AAAA("ns3.example.com. 157127 IN AAAA 2001:4860:4802:36::a"), + test.A("ns1.example.com. 165555 IN A 216.239.32.10"), + test.AAAA("ns1.example.com. 165555 IN AAAA 2001:4860:4802:32::a"), + test.A("ns4.example.com. 190188 IN A 216.239.38.10"), + test.AAAA("ns4.example.com. 
157127 IN AAAA 2001:4860:4802:38::a"), + } + + tests := []struct { + active bool + original test.Case + minimal test.Case + }{ + { // minimization possible NoError case + original: test.Case{ + Answer: baseAnswer, + Ns: nil, + Extra: baseExtra, + Rcode: 0, + }, + minimal: test.Case{ + Answer: baseAnswer, + Ns: nil, + Extra: nil, + Rcode: 0, + }, + }, + { // delegate response case + original: test.Case{ + Answer: nil, + Ns: baseNs, + Extra: baseExtra, + Rcode: 0, + }, + minimal: test.Case{ + Answer: nil, + Ns: baseNs, + Extra: baseExtra, + Rcode: 0, + }, + }, { // negative response case + original: test.Case{ + Answer: baseAnswer, + Ns: baseNs, + Extra: baseExtra, + Rcode: 2, + }, + minimal: test.Case{ + Answer: baseAnswer, + Ns: baseNs, + Extra: baseExtra, + Rcode: 2, + }, + }, + } + + for i, tc := range tests { + req := new(dns.Msg) + req.SetQuestion("example.com", dns.TypeA) + + tHandler := &testHandler{ + Response: &tc.original, + Next: nil, + } + o := &minimalHandler{Next: tHandler} + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err := o.ServeDNS(context.TODO(), rec, req) + + if err != nil { + t.Errorf("Expected no error, but got %q", err) + } + + if len(tc.minimal.Answer) != len(rec.Msg.Answer) { + t.Errorf("Test %d: Expected %d Answer, but got %d", i, len(tc.minimal.Answer), len(req.Answer)) + continue + } + if len(tc.minimal.Ns) != len(rec.Msg.Ns) { + t.Errorf("Test %d: Expected %d Ns, but got %d", i, len(tc.minimal.Ns), len(req.Ns)) + continue + } + + if len(tc.minimal.Extra) != len(rec.Msg.Extra) { + t.Errorf("Test %d: Expected %d Extras, but got %d", i, len(tc.minimal.Extra), len(req.Extra)) + continue + } + + for j, a := range rec.Msg.Answer { + if tc.minimal.Answer[j].String() != a.String() { + t.Errorf("Test %d: Expected Answer %d to be %v, but got %v", i, j, tc.minimal.Answer[j], a) + } + } + + for j, a := range rec.Msg.Ns { + if tc.minimal.Ns[j].String() != a.String() { + t.Errorf("Test %d: Expected NS %d to be %v, but got %v", i, j, tc.minimal.Ns[j], a) + } + } + + for j, a := range rec.Msg.Extra { + if tc.minimal.Extra[j].String() != a.String() { + t.Errorf("Test %d: Expected Extra %d to be %v, but got %v", i, j, tc.minimal.Extra[j], a) + } + } + } +} diff --git a/plugin/minimal/setup.go b/plugin/minimal/setup.go new file mode 100644 index 0000000..1bf37a6 --- /dev/null +++ b/plugin/minimal/setup.go @@ -0,0 +1,24 @@ +package minimal + +import ( + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" +) + +func init() { + plugin.Register("minimal", setup) +} + +func setup(c *caddy.Controller) error { + c.Next() + if c.NextArg() { + return plugin.Error("minimal", c.ArgErr()) + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + return &minimalHandler{Next: next} + }) + + return nil +} diff --git a/plugin/minimal/setup_test.go b/plugin/minimal/setup_test.go new file mode 100644 index 0000000..49341c4 --- /dev/null +++ b/plugin/minimal/setup_test.go @@ -0,0 +1,19 @@ +package minimal + +import ( + "testing" + + "github.com/coredns/caddy" +) + +func TestSetup(t *testing.T) { + c := caddy.NewTestController("dns", `minimal-response`) + if err := setup(c); err != nil { + t.Fatalf("Expected no errors, but got: %v", err) + } + + c = caddy.NewTestController("dns", `minimal-response example.org`) + if err := setup(c); err == nil { + t.Fatalf("Expected errors, but got: %v", err) + } +} diff --git a/plugin/normalize.go b/plugin/normalize.go new file mode 100644 index 0000000..4b92bb4 --- 
/dev/null +++ b/plugin/normalize.go @@ -0,0 +1,196 @@ +package plugin + +import ( + "fmt" + "net" + "runtime" + "strconv" + "strings" + + "github.com/coredns/coredns/plugin/pkg/cidr" + "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/plugin/pkg/parse" + + "github.com/miekg/dns" +) + +// See core/dnsserver/address.go - we should unify these two impls. + +// Zones represents a lists of zone names. +type Zones []string + +// Matches checks if qname is a subdomain of any of the zones in z. The match +// will return the most specific zones that matches. The empty string +// signals a not found condition. +func (z Zones) Matches(qname string) string { + zone := "" + for _, zname := range z { + if dns.IsSubDomain(zname, qname) { + // We want the *longest* matching zone, otherwise we may end up in a parent + if len(zname) > len(zone) { + zone = zname + } + } + } + return zone +} + +// Normalize fully qualifies all zones in z. The zones in Z must be domain names, without +// a port or protocol prefix. +func (z Zones) Normalize() { + for i := range z { + z[i] = Name(z[i]).Normalize() + } +} + +// Name represents a domain name. +type Name string + +// Matches checks to see if other is a subdomain (or the same domain) of n. +// This method assures that names can be easily and consistently matched. +func (n Name) Matches(child string) bool { + if dns.Name(n) == dns.Name(child) { + return true + } + return dns.IsSubDomain(string(n), child) +} + +// Normalize lowercases and makes n fully qualified. +func (n Name) Normalize() string { return strings.ToLower(dns.Fqdn(string(n))) } + +type ( + // Host represents a host from the Corefile, may contain port. + Host string +) + +// Normalize will return the host portion of host, stripping +// of any port or transport. The host will also be fully qualified and lowercased. +// An empty string is returned on failure +// Deprecated: use OriginsFromArgsOrServerBlock or NormalizeExact +func (h Host) Normalize() string { + var caller string + if _, file, line, ok := runtime.Caller(1); ok { + caller = fmt.Sprintf("(%v line %d) ", file, line) + } + log.Warning("An external plugin " + caller + "is using the deprecated function Normalize. " + + "This will be removed in a future versions of CoreDNS. The plugin should be updated to use " + + "OriginsFromArgsOrServerBlock or NormalizeExact instead.") + + s := string(h) + _, s = parse.Transport(s) + + // The error can be ignored here, because this function is called after the corefile has already been vetted. + hosts, _, err := SplitHostPort(s) + if err != nil { + return "" + } + return Name(hosts[0]).Normalize() +} + +// MustNormalize will return the host portion of host, stripping +// of any port or transport. The host will also be fully qualified and lowercased. +// An error is returned on error +// Deprecated: use OriginsFromArgsOrServerBlock or NormalizeExact +func (h Host) MustNormalize() (string, error) { + var caller string + if _, file, line, ok := runtime.Caller(1); ok { + caller = fmt.Sprintf("(%v line %d) ", file, line) + } + log.Warning("An external plugin " + caller + "is using the deprecated function MustNormalize. " + + "This will be removed in a future versions of CoreDNS. The plugin should be updated to use " + + "OriginsFromArgsOrServerBlock or NormalizeExact instead.") + + s := string(h) + _, s = parse.Transport(s) + + // The error can be ignored here, because this function is called after the corefile has already been vetted. 
+ hosts, _, err := SplitHostPort(s) + if err != nil { + return "", err + } + return Name(hosts[0]).Normalize(), nil +} + +// NormalizeExact will return the host portion of host, stripping +// of any port or transport. The host will also be fully qualified and lowercased. +// An empty slice is returned on failure +func (h Host) NormalizeExact() []string { + // The error can be ignored here, because this function should only be called after the corefile has already been vetted. + s := string(h) + _, s = parse.Transport(s) + + hosts, _, err := SplitHostPort(s) + if err != nil { + return nil + } + for i := range hosts { + hosts[i] = Name(hosts[i]).Normalize() + } + return hosts +} + +// SplitHostPort splits s up in a host(s) and port portion, taking reverse address notation into account. +// String the string s should *not* be prefixed with any protocols, i.e. dns://. SplitHostPort can return +// multiple hosts when a reverse notation on a non-octet boundary is given. +func SplitHostPort(s string) (hosts []string, port string, err error) { + // If there is: :[0-9]+ on the end we assume this is the port. This works for (ascii) domain + // names and our reverse syntax, which always needs a /mask *before* the port. + // So from the back, find first colon, and then check if it's a number. + colon := strings.LastIndex(s, ":") + if colon == len(s)-1 { + return nil, "", fmt.Errorf("expecting data after last colon: %q", s) + } + if colon != -1 { + if p, err := strconv.Atoi(s[colon+1:]); err == nil { + port = strconv.Itoa(p) + s = s[:colon] + } + } + + // TODO(miek): this should take escaping into account. + if len(s) > 255 { + return nil, "", fmt.Errorf("specified zone is too long: %d > 255", len(s)) + } + + if _, ok := dns.IsDomainName(s); !ok { + return nil, "", fmt.Errorf("zone is not a valid domain name: %s", s) + } + + // Check if it parses as a reverse zone, if so we use that. Must be fully specified IP and mask. + _, n, err := net.ParseCIDR(s) + if err != nil { + return []string{s}, port, nil + } + + if s[0] == ':' || (s[0] == '0' && strings.Contains(s, ":")) { + return nil, "", fmt.Errorf("invalid CIDR %s", s) + } + + // now check if multiple hosts must be returned. + nets := cidr.Split(n) + hosts = cidr.Reverse(nets) + return hosts, port, nil +} + +// OriginsFromArgsOrServerBlock returns the normalized args if that slice +// is not empty, otherwise the serverblock slice is returned (in a newly copied slice). +func OriginsFromArgsOrServerBlock(args, serverblock []string) []string { + if len(args) == 0 { + s := make([]string, len(serverblock)) + copy(s, serverblock) + for i := range s { + s[i] = Host(s[i]).NormalizeExact()[0] // expansion of these already happened in dnsserver/register.go + } + return s + } + s := []string{} + for i := range args { + sx := Host(args[i]).NormalizeExact() + if len(sx) == 0 { + continue // silently ignores errors. + } + s = append(s, sx...) + } + + return s +} diff --git a/plugin/normalize_test.go b/plugin/normalize_test.go new file mode 100644 index 0000000..cc32eae --- /dev/null +++ b/plugin/normalize_test.go @@ -0,0 +1,140 @@ +package plugin + +import ( + "sort" + "testing" +) + +func TestZoneMatches(t *testing.T) { + child := "example.org." + zones := Zones([]string{"org.", "."}) + actual := zones.Matches(child) + if actual != "org." { + t.Errorf("Expected %v, got %v", "org.", actual) + } + + child = "bla.example.org." + zones = Zones([]string{"bla.example.org.", "org.", "."}) + actual = zones.Matches(child) + + if actual != "bla.example.org." 
{ + t.Errorf("Expected %v, got %v", "org.", actual) + } +} + +func TestZoneNormalize(t *testing.T) { + zones := Zones([]string{"example.org", "Example.ORG.", "example.org."}) + expected := "example.org." + zones.Normalize() + + for _, actual := range zones { + if actual != expected { + t.Errorf("Expected %v, got %v", expected, actual) + } + } +} + +func TestNameMatches(t *testing.T) { + matches := []struct { + child string + parent string + expected bool + }{ + {".", ".", true}, + {"example.org.", ".", true}, + {"example.org.", "example.org.", true}, + {"example.org.", "org.", true}, + {"org.", "example.org.", false}, + } + + for _, m := range matches { + actual := Name(m.parent).Matches(m.child) + if actual != m.expected { + t.Errorf("Expected %v for %s/%s, got %v", m.expected, m.parent, m.child, actual) + } + } +} + +func TestNameNormalize(t *testing.T) { + names := []string{ + "example.org", "example.org.", + "Example.ORG.", "example.org."} + + for i := 0; i < len(names); i += 2 { + ts := names[i] + expected := names[i+1] + actual := Name(ts).Normalize() + if expected != actual { + t.Errorf("Expected %v, got %v", expected, actual) + } + } +} + +func TestHostNormalizeExact(t *testing.T) { + tests := []struct { + in string + out []string + }{ + {".:53", []string{"."}}, + {"example.org:53", []string{"example.org."}}, + {"example.org.:53", []string{"example.org."}}, + {"10.0.0.0/8:53", []string{"10.in-addr.arpa."}}, + {"10.0.0.0/15", []string{"0.10.in-addr.arpa.", "1.10.in-addr.arpa."}}, + {"10.9.3.0/18", []string{"0.9.10.in-addr.arpa.", "1.9.10.in-addr.arpa.", "2.9.10.in-addr.arpa."}}, + {"2001:db8::/29", []string{ + "8.b.d.0.1.0.0.2.ip6.arpa.", + "9.b.d.0.1.0.0.2.ip6.arpa.", + "a.b.d.0.1.0.0.2.ip6.arpa.", + "b.b.d.0.1.0.0.2.ip6.arpa.", + "c.b.d.0.1.0.0.2.ip6.arpa.", + "d.b.d.0.1.0.0.2.ip6.arpa.", + "e.b.d.0.1.0.0.2.ip6.arpa.", + "f.b.d.0.1.0.0.2.ip6.arpa.", + }}, + {"2001:db8::/30", []string{ + "8.b.d.0.1.0.0.2.ip6.arpa.", + "9.b.d.0.1.0.0.2.ip6.arpa.", + "a.b.d.0.1.0.0.2.ip6.arpa.", + "b.b.d.0.1.0.0.2.ip6.arpa.", + }}, + {"2001:db8::/115", []string{ + "0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", + "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", + }}, + {"2001:db8::/114", []string{ + "0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", + "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", + "2.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", + "3.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", + }}, + {"2001:db8::/113", []string{ + "0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", + "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", + "2.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", + "3.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", + "4.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", + "5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", + "6.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", + "7.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", + }}, + {"2001:db8::/112", []string{ + "0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", + }}, + {"2001:db8::/108", []string{ + "0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", + }}, + {"::fFFF:B:F/115", nil}, + {"dns://example.org", []string{"example.org."}}, + } + + for i := range tests { + actual := 
Host(tests[i].in).NormalizeExact() + expected := tests[i].out + sort.Strings(expected) + for j := range expected { + if expected[j] != actual[j] { + t.Errorf("Test %d, expected %v, got %v", i, expected[j], actual[j]) + } + } + } +} diff --git a/plugin/nsid/README.md b/plugin/nsid/README.md new file mode 100644 index 0000000..7bb15ca --- /dev/null +++ b/plugin/nsid/README.md @@ -0,0 +1,57 @@ +# nsid + +## Name + +*nsid* - adds an identifier of this server to each reply. + +## Description + +This plugin implements [RFC 5001](https://tools.ietf.org/html/rfc5001) and adds an EDNS0 OPT +resource record to replies that uniquely identify the server. This is useful in anycast setups to +see which server was responsible for generating the reply and for debugging. + +This plugin can only be used once per Server Block. + + +## Syntax + +~~~ txt +nsid [DATA] +~~~ + +**DATA** is the string to use in the nsid record. + +If **DATA** is not given, the host's name is used. + +## Examples + +Enable nsid: + +~~~ corefile +example.org { + whoami + nsid Use The Force +} +~~~ + +And now a client with NSID support will see an OPT record with the NSID option: + +~~~ sh +% dig +nsid @localhost a whoami.example.org + +;; Got answer: +;; ->>HEADER<<- opcode: QUERY, status: NOERROR, id: 46880 +;; flags: qr aa rd; QUERY: 1, ANSWER: 0, AUTHORITY: 0, ADDITIONAL: 3 + +.... + +; OPT PSEUDOSECTION: +; EDNS: version: 0, flags:; udp: 4096 +; NSID: 55 73 65 20 54 68 65 20 46 6f 72 63 65 ("Use The Force") +;; QUESTION SECTION: +;whoami.example.org. IN A +~~~ + +## See Also + +[RFC 5001](https://tools.ietf.org/html/rfc5001) diff --git a/plugin/nsid/log_test.go b/plugin/nsid/log_test.go new file mode 100644 index 0000000..1aea379 --- /dev/null +++ b/plugin/nsid/log_test.go @@ -0,0 +1,5 @@ +package nsid + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/nsid/nsid.go b/plugin/nsid/nsid.go new file mode 100644 index 0000000..e2506b4 --- /dev/null +++ b/plugin/nsid/nsid.go @@ -0,0 +1,69 @@ +// Package nsid implements NSID protocol +package nsid + +import ( + "context" + "encoding/hex" + + "github.com/coredns/coredns/plugin" + + "github.com/miekg/dns" +) + +// Nsid plugin +type Nsid struct { + Next plugin.Handler + Data string +} + +// ResponseWriter is a response writer that adds NSID response +type ResponseWriter struct { + dns.ResponseWriter + Data string + request *dns.Msg +} + +// ServeDNS implements the plugin.Handler interface. +func (n Nsid) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + if option := r.IsEdns0(); option != nil { + for _, o := range option.Option { + if _, ok := o.(*dns.EDNS0_NSID); ok { + nw := &ResponseWriter{ResponseWriter: w, Data: n.Data, request: r} + return plugin.NextOrFailure(n.Name(), n.Next, ctx, nw, r) + } + } + } + return plugin.NextOrFailure(n.Name(), n.Next, ctx, w, r) +} + +// WriteMsg implements the dns.ResponseWriter interface. 
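+// The NSID option carries the hex encoding of the configured string, e.g. the test
+// value "NSID" goes out on the wire as "4e534944".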
+func (w *ResponseWriter) WriteMsg(res *dns.Msg) error { + if w.request.IsEdns0() != nil && res.IsEdns0() == nil { + res.SetEdns0(w.request.IsEdns0().UDPSize(), true) + } + + if option := res.IsEdns0(); option != nil { + var exists bool + + for _, o := range option.Option { + if e, ok := o.(*dns.EDNS0_NSID); ok { + e.Code = dns.EDNS0NSID + e.Nsid = hex.EncodeToString([]byte(w.Data)) + exists = true + } + } + + // Append the NSID if it doesn't exist in EDNS0 options + if !exists { + option.Option = append(option.Option, &dns.EDNS0_NSID{ + Code: dns.EDNS0NSID, + Nsid: hex.EncodeToString([]byte(w.Data)), + }) + } + } + + return w.ResponseWriter.WriteMsg(res) +} + +// Name implements the Handler interface. +func (n Nsid) Name() string { return "nsid" } diff --git a/plugin/nsid/nsid_test.go b/plugin/nsid/nsid_test.go new file mode 100644 index 0000000..c6268b4 --- /dev/null +++ b/plugin/nsid/nsid_test.go @@ -0,0 +1,136 @@ +package nsid + +import ( + "context" + "encoding/hex" + "testing" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/cache" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/plugin/whoami" + + "github.com/miekg/dns" +) + +func TestNsid(t *testing.T) { + em := Nsid{ + Data: "NSID", + } + + tests := []struct { + next plugin.Handler + qname string + qtype uint16 + expectedCode int + expectedReply string + expectedErr error + }{ + { + next: whoami.Whoami{}, + qname: ".", + expectedCode: dns.RcodeSuccess, + expectedReply: hex.EncodeToString([]byte("NSID")), + expectedErr: nil, + }, + } + + ctx := context.TODO() + + for i, tc := range tests { + req := new(dns.Msg) + if tc.qtype == 0 { + tc.qtype = dns.TypeA + } + req.SetQuestion(dns.Fqdn(tc.qname), tc.qtype) + req.Question[0].Qclass = dns.ClassINET + + req.SetEdns0(4096, false) + option := req.Extra[0].(*dns.OPT) + option.Option = append(option.Option, &dns.EDNS0_NSID{Code: dns.EDNS0NSID, Nsid: ""}) + em.Next = tc.next + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + code, err := em.ServeDNS(ctx, rec, req) + + if err != tc.expectedErr { + t.Errorf("Test %d: Expected error %v, but got %v", i, tc.expectedErr, err) + } + if code != tc.expectedCode { + t.Errorf("Test %d: Expected status code %d, but got %d", i, tc.expectedCode, code) + } + if tc.expectedReply != "" { + for _, extra := range rec.Msg.Extra { + if option, ok := extra.(*dns.OPT); ok { + e := option.Option[0].(*dns.EDNS0_NSID) + if e.Nsid != tc.expectedReply { + t.Errorf("Test %d: Expected answer %s, but got %s", i, tc.expectedReply, e.Nsid) + } + } + } + } + } +} + +func TestNsidCache(t *testing.T) { + em := Nsid{ + Data: "NSID", + } + c := cache.New() + + tests := []struct { + next plugin.Handler + qname string + qtype uint16 + expectedCode int + expectedReply string + expectedErr error + }{ + { + next: whoami.Whoami{}, + qname: ".", + expectedCode: dns.RcodeSuccess, + expectedReply: hex.EncodeToString([]byte("NSID")), + expectedErr: nil, + }, + } + + ctx := context.TODO() + + for i, tc := range tests { + req := new(dns.Msg) + if tc.qtype == 0 { + tc.qtype = dns.TypeA + } + req.SetQuestion(dns.Fqdn(tc.qname), tc.qtype) + req.Question[0].Qclass = dns.ClassINET + + req.SetEdns0(4096, false) + option := req.Extra[0].(*dns.OPT) + option.Option = append(option.Option, &dns.EDNS0_NSID{Code: dns.EDNS0NSID, Nsid: ""}) + em.Next = tc.next + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + + c.Next = em + code, err := c.ServeDNS(ctx, rec, req) + + if err != 
tc.expectedErr { + t.Errorf("Test %d: Expected error %v, but got %v", i, tc.expectedErr, err) + } + if code != tc.expectedCode { + t.Errorf("Test %d: Expected status code %d, but got %d", i, tc.expectedCode, code) + } + if tc.expectedReply != "" { + for _, extra := range rec.Msg.Extra { + if option, ok := extra.(*dns.OPT); ok { + e := option.Option[0].(*dns.EDNS0_NSID) + if e.Nsid != tc.expectedReply { + t.Errorf("Test %d: Expected answer %s, but got %s", i, tc.expectedReply, e.Nsid) + } + } + } + } + } +} diff --git a/plugin/nsid/setup.go b/plugin/nsid/setup.go new file mode 100644 index 0000000..07c6493 --- /dev/null +++ b/plugin/nsid/setup.go @@ -0,0 +1,45 @@ +package nsid + +import ( + "os" + "strings" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" +) + +func init() { plugin.Register("nsid", setup) } + +func setup(c *caddy.Controller) error { + nsid, err := nsidParse(c) + if err != nil { + return plugin.Error("nsid", err) + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + return Nsid{Next: next, Data: nsid} + }) + + return nil +} + +func nsidParse(c *caddy.Controller) (string, error) { + // Use hostname as the default + nsid, err := os.Hostname() + if err != nil { + nsid = "localhost" + } + i := 0 + for c.Next() { + if i > 0 { + return nsid, plugin.ErrOnce + } + i++ + args := c.RemainingArgs() + if len(args) > 0 { + nsid = strings.Join(args, " ") + } + } + return nsid, nil +} diff --git a/plugin/nsid/setup_test.go b/plugin/nsid/setup_test.go new file mode 100644 index 0000000..15d4042 --- /dev/null +++ b/plugin/nsid/setup_test.go @@ -0,0 +1,68 @@ +package nsid + +import ( + "os" + "strings" + "testing" + + "github.com/coredns/caddy" +) + +func TestSetupNsid(t *testing.T) { + defaultNsid, err := os.Hostname() + if err != nil { + defaultNsid = "localhost" + } + tests := []struct { + input string + shouldErr bool + expectedData string + expectedErrContent string // substring from the expected error. Empty for positive cases. + }{ + {`nsid`, false, defaultNsid, ""}, + {`nsid "ps0"`, false, "ps0", ""}, + {`nsid "worker1"`, false, "worker1", ""}, + {`nsid "tf 2"`, false, "tf 2", ""}, + {`nsid + nsid`, true, "", "plugin"}, + } + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + err := setup(c) + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected error but found %s for input %s", i, err, test.input) + } + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: Expected no error but found one for input %s. Error was: %v", i, test.input, err) + } + + if !strings.Contains(err.Error(), test.expectedErrContent) { + t.Errorf("Test %d: Expected error to contain: %v, found error: %v, input: %s", i, test.expectedErrContent, err, test.input) + } + } + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + nsid, err := nsidParse(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected error but found %s for input %s", i, err, test.input) + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: Expected no error but found one for input %s. Error was: %v", i, test.input, err) + } + + if !strings.Contains(err.Error(), test.expectedErrContent) { + t.Errorf("Test %d: Expected error to contain: %v, found error: %v, input: %s", i, test.expectedErrContent, err, test.input) + } + } + + if !test.shouldErr && nsid != test.expectedData { + t.Errorf("Nsid not correctly set for input %s. 
Expected: %s, actual: %s", test.input, test.expectedData, nsid) + } + } +} diff --git a/plugin/pkg/cache/cache.go b/plugin/pkg/cache/cache.go new file mode 100644 index 0000000..6c4105e --- /dev/null +++ b/plugin/pkg/cache/cache.go @@ -0,0 +1,157 @@ +// Package cache implements a cache. The cache hold 256 shards, each shard +// holds a cache: a map with a mutex. There is no fancy expunge algorithm, it +// just randomly evicts elements when it gets full. +package cache + +import ( + "hash/fnv" + "sync" +) + +// Hash returns the FNV hash of what. +func Hash(what []byte) uint64 { + h := fnv.New64() + h.Write(what) + return h.Sum64() +} + +// Cache is cache. +type Cache struct { + shards [shardSize]*shard +} + +// shard is a cache with random eviction. +type shard struct { + items map[uint64]interface{} + size int + + sync.RWMutex +} + +// New returns a new cache. +func New(size int) *Cache { + ssize := size / shardSize + if ssize < 4 { + ssize = 4 + } + + c := &Cache{} + + // Initialize all the shards + for i := 0; i < shardSize; i++ { + c.shards[i] = newShard(ssize) + } + return c +} + +// Add adds a new element to the cache. If the element already exists it is overwritten. +// Returns true if an existing element was evicted to make room for this element. +func (c *Cache) Add(key uint64, el interface{}) bool { + shard := key & (shardSize - 1) + return c.shards[shard].Add(key, el) +} + +// Get looks up element index under key. +func (c *Cache) Get(key uint64) (interface{}, bool) { + shard := key & (shardSize - 1) + return c.shards[shard].Get(key) +} + +// Remove removes the element indexed with key. +func (c *Cache) Remove(key uint64) { + shard := key & (shardSize - 1) + c.shards[shard].Remove(key) +} + +// Len returns the number of elements in the cache. +func (c *Cache) Len() int { + l := 0 + for _, s := range &c.shards { + l += s.Len() + } + return l +} + +// Walk walks each shard in the cache. +func (c *Cache) Walk(f func(map[uint64]interface{}, uint64) bool) { + for _, s := range &c.shards { + s.Walk(f) + } +} + +// newShard returns a new shard with size. +func newShard(size int) *shard { return &shard{items: make(map[uint64]interface{}), size: size} } + +// Add adds element indexed by key into the cache. Any existing element is overwritten +// Returns true if an existing element was evicted to make room for this element. +func (s *shard) Add(key uint64, el interface{}) bool { + eviction := false + s.Lock() + if len(s.items) >= s.size { + if _, ok := s.items[key]; !ok { + for k := range s.items { + delete(s.items, k) + eviction = true + break + } + } + } + s.items[key] = el + s.Unlock() + return eviction +} + +// Remove removes the element indexed by key from the cache. +func (s *shard) Remove(key uint64) { + s.Lock() + delete(s.items, key) + s.Unlock() +} + +// Evict removes a random element from the cache. +func (s *shard) Evict() { + s.Lock() + for k := range s.items { + delete(s.items, k) + break + } + s.Unlock() +} + +// Get looks up the element indexed under key. +func (s *shard) Get(key uint64) (interface{}, bool) { + s.RLock() + el, found := s.items[key] + s.RUnlock() + return el, found +} + +// Len returns the current length of the cache. +func (s *shard) Len() int { + s.RLock() + l := len(s.items) + s.RUnlock() + return l +} + +// Walk walks the shard for each element the function f is executed while holding a write lock. 
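+// The keys are snapshotted under a read lock first; f then runs per key while the write
+// lock is held, and returning false from f stops the walk early.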
+func (s *shard) Walk(f func(map[uint64]interface{}, uint64) bool) { + s.RLock() + items := make([]uint64, len(s.items)) + i := 0 + for k := range s.items { + items[i] = k + i++ + } + s.RUnlock() + for _, k := range items { + s.Lock() + ok := f(s.items, k) + s.Unlock() + if !ok { + return + } + } +} + +const shardSize = 256 diff --git a/plugin/pkg/cache/cache_test.go b/plugin/pkg/cache/cache_test.go new file mode 100644 index 0000000..e9e0a30 --- /dev/null +++ b/plugin/pkg/cache/cache_test.go @@ -0,0 +1,85 @@ +package cache + +import ( + "testing" +) + +func TestCacheAddAndGet(t *testing.T) { + const N = shardSize * 4 + c := New(N) + c.Add(1, 1) + + if _, found := c.Get(1); !found { + t.Fatal("Failed to find inserted record") + } + + for i := 0; i < N; i++ { + c.Add(uint64(i), 1) + } + for i := 0; i < N; i++ { + c.Add(uint64(i), 1) + if c.Len() != N { + t.Fatal("A item was unnecessarily evicted from the cache") + } + } +} + +func TestCacheLen(t *testing.T) { + c := New(4) + + c.Add(1, 1) + if l := c.Len(); l != 1 { + t.Fatalf("Cache size should %d, got %d", 1, l) + } + + c.Add(1, 1) + if l := c.Len(); l != 1 { + t.Fatalf("Cache size should %d, got %d", 1, l) + } + + c.Add(2, 2) + if l := c.Len(); l != 2 { + t.Fatalf("Cache size should %d, got %d", 2, l) + } +} + +func TestCacheSharding(t *testing.T) { + c := New(shardSize) + for i := 0; i < shardSize*2; i++ { + c.Add(uint64(i), 1) + } + for i, s := range c.shards { + if s.Len() == 0 { + t.Errorf("Failed to populate shard: %d", i) + } + } +} + +func TestCacheWalk(t *testing.T) { + c := New(10) + exp := make([]int, 10*2) + for i := 0; i < 10*2; i++ { + c.Add(uint64(i), 1) + exp[i] = 1 + } + got := make([]int, 10*2) + c.Walk(func(items map[uint64]interface{}, key uint64) bool { + got[key] = items[key].(int) + return true + }) + for i := range exp { + if exp[i] != got[i] { + t.Errorf("Expected %d, got %d", exp[i], got[i]) + } + } +} + +func BenchmarkCache(b *testing.B) { + b.ReportAllocs() + + c := New(4) + for n := 0; n < b.N; n++ { + c.Add(1, 1) + c.Get(1) + } +} diff --git a/plugin/pkg/cache/shard_test.go b/plugin/pkg/cache/shard_test.go new file mode 100644 index 0000000..a383130 --- /dev/null +++ b/plugin/pkg/cache/shard_test.go @@ -0,0 +1,139 @@ +package cache + +import ( + "sync" + "testing" +) + +func TestShardAddAndGet(t *testing.T) { + s := newShard(1) + s.Add(1, 1) + + if _, found := s.Get(1); !found { + t.Fatal("Failed to find inserted record") + } + + s.Add(2, 1) + if _, found := s.Get(1); found { + t.Fatal("Failed to evict record") + } + if _, found := s.Get(2); !found { + t.Fatal("Failed to find inserted record") + } +} + +func TestAddEvict(t *testing.T) { + const size = 1024 + s := newShard(size) + + for i := uint64(0); i < size; i++ { + s.Add(i, 1) + } + for i := uint64(0); i < size; i++ { + s.Add(i, 1) + if s.Len() != size { + t.Fatal("A item was unnecessarily evicted from the cache") + } + } +} + +func TestShardLen(t *testing.T) { + s := newShard(4) + + s.Add(1, 1) + if l := s.Len(); l != 1 { + t.Fatalf("Shard size should %d, got %d", 1, l) + } + + s.Add(1, 1) + if l := s.Len(); l != 1 { + t.Fatalf("Shard size should %d, got %d", 1, l) + } + + s.Add(2, 2) + if l := s.Len(); l != 2 { + t.Fatalf("Shard size should %d, got %d", 2, l) + } +} + +func TestShardEvict(t *testing.T) { + s := newShard(1) + s.Add(1, 1) + s.Add(2, 2) + // 1 should be gone + + if _, found := s.Get(1); found { + t.Fatal("Found item that should have been evicted") + } +} + +func TestShardLenEvict(t *testing.T) { + s := newShard(4) + s.Add(1, 1) + s.Add(2, 
1) + s.Add(3, 1) + s.Add(4, 1) + + if l := s.Len(); l != 4 { + t.Fatalf("Shard size should %d, got %d", 4, l) + } + + // This should evict one element + s.Add(5, 1) + if l := s.Len(); l != 4 { + t.Fatalf("Shard size should %d, got %d", 4, l) + } + + // Make sure we don't accidentally evict an element when + // we the key is already stored. + for i := 0; i < 4; i++ { + s.Add(5, 1) + if l := s.Len(); l != 4 { + t.Fatalf("Shard size should %d, got %d", 4, l) + } + } +} + +func TestShardEvictParallel(t *testing.T) { + s := newShard(shardSize) + for i := uint64(0); i < shardSize; i++ { + s.Add(i, struct{}{}) + } + start := make(chan struct{}) + var wg sync.WaitGroup + for i := 0; i < shardSize; i++ { + wg.Add(1) + go func() { + <-start + s.Evict() + wg.Done() + }() + } + close(start) // start evicting in parallel + wg.Wait() + if s.Len() != 0 { + t.Fatalf("Failed to evict all keys in parallel: %d", s.Len()) + } +} + +func BenchmarkShard(b *testing.B) { + s := newShard(shardSize) + b.ResetTimer() + for i := 0; i < b.N; i++ { + k := uint64(i) % shardSize * 2 + s.Add(k, 1) + s.Get(k) + } +} + +func BenchmarkShardParallel(b *testing.B) { + s := newShard(shardSize) + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for i := uint64(0); pb.Next(); i++ { + k := i % shardSize * 2 + s.Add(k, 1) + s.Get(k) + } + }) +} diff --git a/plugin/pkg/cidr/cidr.go b/plugin/pkg/cidr/cidr.go new file mode 100644 index 0000000..91aead9 --- /dev/null +++ b/plugin/pkg/cidr/cidr.go @@ -0,0 +1,83 @@ +// Package cidr contains functions that deal with classless reverse zones in the DNS. +package cidr + +import ( + "math" + "net" + "strings" + + "github.com/apparentlymart/go-cidr/cidr" + "github.com/miekg/dns" +) + +// Split returns a slice of non-overlapping subnets that in union equal the subnet n, +// and where each subnet falls on a reverse name segment boundary. +// for ipv4 this is any multiple of 8 bits (/8, /16, /24 or /32) +// for ipv6 this is any multiple of 4 bits +func Split(n *net.IPNet) []string { + boundary := 8 + nstr := n.String() + if strings.Contains(nstr, ":") { + boundary = 4 + } + ones, _ := n.Mask.Size() + if ones%boundary == 0 { + return []string{n.String()} + } + + mask := int(math.Ceil(float64(ones)/float64(boundary))) * boundary + networks := nets(n, mask) + cidrs := make([]string, len(networks)) + for i := range networks { + cidrs[i] = networks[i].String() + } + return cidrs +} + +// nets return a slice of prefixes with the desired mask subnetted from original network. +func nets(network *net.IPNet, newPrefixLen int) []*net.IPNet { + prefixLen, _ := network.Mask.Size() + maxSubnets := int(math.Exp2(float64(newPrefixLen)) / math.Exp2(float64(prefixLen))) + nets := []*net.IPNet{{IP: network.IP, Mask: net.CIDRMask(newPrefixLen, 8*len(network.IP))}} + + for i := 1; i < maxSubnets; i++ { + next, exceeds := cidr.NextSubnet(nets[len(nets)-1], newPrefixLen) + nets = append(nets, next) + if exceeds { + break + } + } + + return nets +} + +// Reverse return the reverse zones that are authoritative for each net in ns. +func Reverse(nets []string) []string { + rev := make([]string, len(nets)) + for i := range nets { + ip, n, _ := net.ParseCIDR(nets[i]) + r, err1 := dns.ReverseAddr(ip.String()) + if err1 != nil { + continue + } + ones, bits := n.Mask.Size() + // get the size, in bits, of each portion of hostname defined in the reverse address. 
(8 for IPv4, 4 for IPv6) + sizeDigit := 8 + if len(n.IP) == net.IPv6len { + sizeDigit = 4 + } + // Get the first lower octet boundary to see what encompassing zone we should be authoritative for. + mod := (bits - ones) % sizeDigit + nearest := (bits - ones) + mod + offset := 0 + var end bool + for i := 0; i < nearest/sizeDigit; i++ { + offset, end = dns.NextLabel(r, offset) + if end { + break + } + } + rev[i] = r[offset:] + } + return rev +} diff --git a/plugin/pkg/cidr/cidr_test.go b/plugin/pkg/cidr/cidr_test.go new file mode 100644 index 0000000..055a82e --- /dev/null +++ b/plugin/pkg/cidr/cidr_test.go @@ -0,0 +1,47 @@ +package cidr + +import ( + "net" + "testing" +) + +var tests = []struct { + in string + expected []string + zones []string +}{ + {"10.0.0.0/15", []string{"10.0.0.0/16", "10.1.0.0/16"}, []string{"0.10.in-addr.arpa.", "1.10.in-addr.arpa."}}, + {"10.0.0.0/16", []string{"10.0.0.0/16"}, []string{"0.10.in-addr.arpa."}}, + {"192.168.1.1/23", []string{"192.168.0.0/24", "192.168.1.0/24"}, []string{"0.168.192.in-addr.arpa.", "1.168.192.in-addr.arpa."}}, + {"10.129.60.0/22", []string{"10.129.60.0/24", "10.129.61.0/24", "10.129.62.0/24", "10.129.63.0/24"}, []string{"60.129.10.in-addr.arpa.", "61.129.10.in-addr.arpa.", "62.129.10.in-addr.arpa.", "63.129.10.in-addr.arpa."}}, + {"2001:db8::/31", []string{"2001:db8::/32", "2001:db9::/32"}, []string{"8.b.d.0.1.0.0.2.ip6.arpa.", "9.b.d.0.1.0.0.2.ip6.arpa."}}, +} + +func TestSplit(t *testing.T) { + for i, tc := range tests { + _, n, _ := net.ParseCIDR(tc.in) + nets := Split(n) + if len(nets) != len(tc.expected) { + t.Errorf("Test %d, expected %d subnets, got %d", i, len(tc.expected), len(nets)) + continue + } + for j := range nets { + if nets[j] != tc.expected[j] { + t.Errorf("Test %d, expected %s, got %s", i, tc.expected[j], nets[j]) + } + } + } +} + +func TestReverse(t *testing.T) { + for i, tc := range tests { + _, n, _ := net.ParseCIDR(tc.in) + nets := Split(n) + reverse := Reverse(nets) + for j := range reverse { + if reverse[j] != tc.zones[j] { + t.Errorf("Test %d, expected %s, got %s", i, tc.zones[j], reverse[j]) + } + } + } +} diff --git a/plugin/pkg/dnstest/multirecorder.go b/plugin/pkg/dnstest/multirecorder.go new file mode 100644 index 0000000..fe8ee03 --- /dev/null +++ b/plugin/pkg/dnstest/multirecorder.go @@ -0,0 +1,41 @@ +package dnstest + +import ( + "time" + + "github.com/miekg/dns" +) + +// MultiRecorder is a type of ResponseWriter that captures all messages written to it. +type MultiRecorder struct { + Len int + Msgs []*dns.Msg + Start time.Time + dns.ResponseWriter +} + +// NewMultiRecorder makes and returns a new MultiRecorder. +func NewMultiRecorder(w dns.ResponseWriter) *MultiRecorder { + return &MultiRecorder{ + ResponseWriter: w, + Msgs: make([]*dns.Msg, 0), + Start: time.Now(), + } +} + +// WriteMsg records the message and its length written to it and call the +// underlying ResponseWriter's WriteMsg method. +func (r *MultiRecorder) WriteMsg(res *dns.Msg) error { + r.Len += res.Len() + r.Msgs = append(r.Msgs, res) + return r.ResponseWriter.WriteMsg(res) +} + +// Write is a wrapper that records the length of the messages that get written to it. 
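+// The returned byte count is added to Len only when the underlying
+// ResponseWriter reports no error, so Len reflects bytes actually written.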
+func (r *MultiRecorder) Write(buf []byte) (int, error) { + n, err := r.ResponseWriter.Write(buf) + if err == nil { + r.Len += n + } + return n, err +} diff --git a/plugin/pkg/dnstest/multirecorder_test.go b/plugin/pkg/dnstest/multirecorder_test.go new file mode 100644 index 0000000..1299db5 --- /dev/null +++ b/plugin/pkg/dnstest/multirecorder_test.go @@ -0,0 +1,38 @@ +package dnstest + +import ( + "testing" + + "github.com/miekg/dns" +) + +func TestMultiWriteMsg(t *testing.T) { + w := &responseWriter{} + record := NewMultiRecorder(w) + + responseTestName := "testmsg.example.org." + responseTestMsg := new(dns.Msg) + responseTestMsg.SetQuestion(responseTestName, dns.TypeA) + + record.WriteMsg(responseTestMsg) + record.WriteMsg(responseTestMsg) + + if len(record.Msgs) != 2 { + t.Fatalf("Expected 2 messages to be written, but instead found %d\n", len(record.Msgs)) + } + if record.Len != responseTestMsg.Len()*2 { + t.Fatalf("Expected the bytes written counter to be %d, but instead found %d\n", responseTestMsg.Len()*2, record.Len) + } +} + +func TestMultiWrite(t *testing.T) { + w := &responseWriter{} + record := NewRecorder(w) + responseTest := []byte("testmsg.example.org.") + + record.Write(responseTest) + record.Write(responseTest) + if record.Len != len(responseTest)*2 { + t.Fatalf("Expected the bytes written counter to be %d, but instead found %d\n", len(responseTest)*2, record.Len) + } +} diff --git a/plugin/pkg/dnstest/recorder.go b/plugin/pkg/dnstest/recorder.go new file mode 100644 index 0000000..1da063e --- /dev/null +++ b/plugin/pkg/dnstest/recorder.go @@ -0,0 +1,54 @@ +// Package dnstest allows for easy testing of DNS client against a test server. +package dnstest + +import ( + "time" + + "github.com/miekg/dns" +) + +// Recorder is a type of ResponseWriter that captures +// the rcode code written to it and also the size of the message +// written in the response. A rcode code does not have +// to be written, however, in which case 0 must be assumed. +// It is best to have the constructor initialize this type +// with that default status code. +type Recorder struct { + dns.ResponseWriter + Rcode int + Len int + Msg *dns.Msg + Start time.Time +} + +// NewRecorder makes and returns a new Recorder, +// which captures the DNS rcode from the ResponseWriter +// and also the length of the response message written through it. +func NewRecorder(w dns.ResponseWriter) *Recorder { + return &Recorder{ + ResponseWriter: w, + Rcode: 0, + Msg: nil, + Start: time.Now(), + } +} + +// WriteMsg records the status code and calls the +// underlying ResponseWriter's WriteMsg method. +func (r *Recorder) WriteMsg(res *dns.Msg) error { + r.Rcode = res.Rcode + // We may get called multiple times (axfr for instance). + // Save the last message, but add the sizes. + r.Len += res.Len() + r.Msg = res + return r.ResponseWriter.WriteMsg(res) +} + +// Write is a wrapper that records the length of the message that gets written. 
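+// As with WriteMsg, the count accumulates, so Len ends up holding the total
+// size written through this Recorder across all calls.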
+func (r *Recorder) Write(buf []byte) (int, error) { + n, err := r.ResponseWriter.Write(buf) + if err == nil { + r.Len += n + } + return n, err +} diff --git a/plugin/pkg/dnstest/recorder_test.go b/plugin/pkg/dnstest/recorder_test.go new file mode 100644 index 0000000..96af7b0 --- /dev/null +++ b/plugin/pkg/dnstest/recorder_test.go @@ -0,0 +1,50 @@ +package dnstest + +import ( + "testing" + + "github.com/miekg/dns" +) + +type responseWriter struct{ dns.ResponseWriter } + +func (r *responseWriter) WriteMsg(m *dns.Msg) error { return nil } +func (r *responseWriter) Write(buf []byte) (int, error) { return len(buf), nil } + +func TestNewRecorder(t *testing.T) { + w := &responseWriter{} + record := NewRecorder(w) + if record.ResponseWriter != w { + t.Fatalf("Expected Response writer in the Recording to be same as the one sent\n") + } + if record.Rcode != dns.RcodeSuccess { + t.Fatalf("Expected recorded status to be dns.RcodeSuccess (%d) , but found %d\n ", dns.RcodeSuccess, record.Rcode) + } +} + +func TestWriteMsg(t *testing.T) { + w := &responseWriter{} + record := NewRecorder(w) + responseTestName := "testmsg.example.org." + responseTestMsg := new(dns.Msg) + responseTestMsg.SetQuestion(responseTestName, dns.TypeA) + + record.WriteMsg(responseTestMsg) + if record.Len != responseTestMsg.Len() { + t.Fatalf("Expected the bytes written counter to be %d, but instead found %d\n", responseTestMsg.Len(), record.Len) + } + if x := record.Msg.Question[0].Name; x != responseTestName { + t.Fatalf("Expected Msg Qname to be %s , but found %s\n", responseTestName, x) + } +} + +func TestWrite(t *testing.T) { + w := &responseWriter{} + record := NewRecorder(w) + responseTest := []byte("testmsg.example.org.") + + record.Write(responseTest) + if record.Len != len(responseTest) { + t.Fatalf("Expected the bytes written counter to be %d, but instead found %d\n", len(responseTest), record.Len) + } +} diff --git a/plugin/pkg/dnstest/server.go b/plugin/pkg/dnstest/server.go new file mode 100644 index 0000000..94c3906 --- /dev/null +++ b/plugin/pkg/dnstest/server.go @@ -0,0 +1,65 @@ +package dnstest + +import ( + "net" + + "github.com/coredns/coredns/plugin/pkg/reuseport" + + "github.com/miekg/dns" +) + +// A Server is an DNS server listening on a system-chosen port on the local +// loopback interface, for use in end-to-end DNS tests. +type Server struct { + Addr string // Address where the server listening. + + s1 *dns.Server // udp + s2 *dns.Server // tcp +} + +// NewServer starts and returns a new Server. The caller should call Close when +// finished, to shut it down. +func NewServer(f dns.HandlerFunc) *Server { + dns.HandleFunc(".", f) + + ch1 := make(chan bool) + ch2 := make(chan bool) + + s1 := &dns.Server{} // udp + s2 := &dns.Server{} // tcp + + for i := 0; i < 5; i++ { // 5 attempts + s2.Listener, _ = reuseport.Listen("tcp", ":0") + if s2.Listener == nil { + continue + } + + s1.PacketConn, _ = net.ListenPacket("udp", s2.Listener.Addr().String()) + if s1.PacketConn != nil { + break + } + + // perhaps UPD port is in use, try again + s2.Listener.Close() + s2.Listener = nil + } + if s2.Listener == nil { + panic("dnstest.NewServer(): failed to create new server") + } + + s1.NotifyStartedFunc = func() { close(ch1) } + s2.NotifyStartedFunc = func() { close(ch2) } + go s1.ActivateAndServe() + go s2.ActivateAndServe() + + <-ch1 + <-ch2 + + return &Server{s1: s1, s2: s2, Addr: s2.Listener.Addr().String()} +} + +// Close shuts down the server. 
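+// Both the UDP and TCP servers started by NewServer are shut down. Tests
+// typically arrange this with a deferred call: defer s.Close().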
+func (s *Server) Close() { + s.s1.Shutdown() + s.s2.Shutdown() +} diff --git a/plugin/pkg/dnstest/server_test.go b/plugin/pkg/dnstest/server_test.go new file mode 100644 index 0000000..41450e4 --- /dev/null +++ b/plugin/pkg/dnstest/server_test.go @@ -0,0 +1,37 @@ +package dnstest + +import ( + "testing" + + "github.com/miekg/dns" +) + +func TestNewServer(t *testing.T) { + s := NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + c := new(dns.Client) + c.Net = "tcp" + m := new(dns.Msg) + m.SetQuestion("example.org.", dns.TypeSOA) + ret, _, err := c.Exchange(m, s.Addr) + if err != nil { + t.Fatalf("Could not send message to dnstest.Server: %s", err) + } + if ret.Id != m.Id { + t.Fatalf("Msg ID's should match, expected %d, got %d", m.Id, ret.Id) + } + + c.Net = "udp" + ret, _, err = c.Exchange(m, s.Addr) + if err != nil { + t.Fatalf("Could not send message to dnstest.Server: %s", err) + } + if ret.Id != m.Id { + t.Fatalf("Msg ID's should match, expected %d, got %d", m.Id, ret.Id) + } +} diff --git a/plugin/pkg/dnsutil/cname.go b/plugin/pkg/dnsutil/cname.go new file mode 100644 index 0000000..281e032 --- /dev/null +++ b/plugin/pkg/dnsutil/cname.go @@ -0,0 +1,15 @@ +package dnsutil + +import "github.com/miekg/dns" + +// DuplicateCNAME returns true if r already exists in records. +func DuplicateCNAME(r *dns.CNAME, records []dns.RR) bool { + for _, rec := range records { + if v, ok := rec.(*dns.CNAME); ok { + if v.Target == r.Target { + return true + } + } + } + return false +} diff --git a/plugin/pkg/dnsutil/cname_test.go b/plugin/pkg/dnsutil/cname_test.go new file mode 100644 index 0000000..5fb8d30 --- /dev/null +++ b/plugin/pkg/dnsutil/cname_test.go @@ -0,0 +1,55 @@ +package dnsutil + +import ( + "testing" + + "github.com/miekg/dns" +) + +func TestDuplicateCNAME(t *testing.T) { + tests := []struct { + cname string + records []string + expected bool + }{ + { + "1.0.0.192.IN-ADDR.ARPA. 3600 IN CNAME 1.0.0.0.192.IN-ADDR.ARPA.", + []string{ + "US. 86400 IN NSEC 0-.us. NS SOA RRSIG NSEC DNSKEY TYPE65534", + "1.0.0.192.IN-ADDR.ARPA. 3600 IN CNAME 1.0.0.0.192.IN-ADDR.ARPA.", + }, + true, + }, + { + "1.0.0.192.IN-ADDR.ARPA. 3600 IN CNAME 1.0.0.0.192.IN-ADDR.ARPA.", + []string{ + "US. 86400 IN NSEC 0-.us. NS SOA RRSIG NSEC DNSKEY TYPE65534", + }, + false, + }, + { + "1.0.0.192.IN-ADDR.ARPA. 3600 IN CNAME 1.0.0.0.192.IN-ADDR.ARPA.", + []string{}, + false, + }, + } + for i, test := range tests { + cnameRR, err := dns.NewRR(test.cname) + if err != nil { + t.Fatalf("Test %d, cname ('%s') error (%s)!", i, test.cname, err) + } + cname := cnameRR.(*dns.CNAME) + records := []dns.RR{} + for j, r := range test.records { + rr, err := dns.NewRR(r) + if err != nil { + t.Fatalf("Test %d, record %d ('%s') error (%s)!", i, j, r, err) + } + records = append(records, rr) + } + got := DuplicateCNAME(cname, records) + if got != test.expected { + t.Errorf("Test %d, expected '%v', got '%v' for CNAME ('%s') and RECORDS (%v)", i, test.expected, got, test.cname, test.records) + } + } +} diff --git a/plugin/pkg/dnsutil/doc.go b/plugin/pkg/dnsutil/doc.go new file mode 100644 index 0000000..75d1e8c --- /dev/null +++ b/plugin/pkg/dnsutil/doc.go @@ -0,0 +1,2 @@ +// Package dnsutil contains DNS related helper functions. 
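+// Among the helpers are duplicate CNAME detection, joining labels into a
+// fully qualified name, reverse (PTR) name handling, minimal TTL calculation
+// and trimming a zone from a query name.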
+package dnsutil diff --git a/plugin/pkg/dnsutil/join.go b/plugin/pkg/dnsutil/join.go new file mode 100644 index 0000000..b3a40db --- /dev/null +++ b/plugin/pkg/dnsutil/join.go @@ -0,0 +1,17 @@ +package dnsutil + +import ( + "strings" + + "github.com/miekg/dns" +) + +// Join joins labels to form a fully qualified domain name. If the last label is +// the root label it is ignored. Not other syntax checks are performed. +func Join(labels ...string) string { + ll := len(labels) + if labels[ll-1] == "." { + return strings.Join(labels[:ll-1], ".") + "." + } + return dns.Fqdn(strings.Join(labels, ".")) +} diff --git a/plugin/pkg/dnsutil/join_test.go b/plugin/pkg/dnsutil/join_test.go new file mode 100644 index 0000000..1a50a3c --- /dev/null +++ b/plugin/pkg/dnsutil/join_test.go @@ -0,0 +1,21 @@ +package dnsutil + +import "testing" + +func TestJoin(t *testing.T) { + tests := []struct { + in []string + out string + }{ + {[]string{"bla", "bliep", "example", "org"}, "bla.bliep.example.org."}, + {[]string{"example", "."}, "example."}, + {[]string{"example", "org."}, "example.org."}, // technically we should not be called like this. + {[]string{"."}, "."}, + } + + for i, tc := range tests { + if x := Join(tc.in...); x != tc.out { + t.Errorf("Test %d, expected %s, got %s", i, tc.out, x) + } + } +} diff --git a/plugin/pkg/dnsutil/reverse.go b/plugin/pkg/dnsutil/reverse.go new file mode 100644 index 0000000..7bfd235 --- /dev/null +++ b/plugin/pkg/dnsutil/reverse.go @@ -0,0 +1,81 @@ +package dnsutil + +import ( + "net" + "strings" +) + +// ExtractAddressFromReverse turns a standard PTR reverse record name +// into an IP address. This works for ipv4 or ipv6. +// +// 54.119.58.176.in-addr.arpa. becomes 176.58.119.54. If the conversion +// fails the empty string is returned. +func ExtractAddressFromReverse(reverseName string) string { + search := "" + + f := reverse + + switch { + case strings.HasSuffix(reverseName, IP4arpa): + search = strings.TrimSuffix(reverseName, IP4arpa) + case strings.HasSuffix(reverseName, IP6arpa): + search = strings.TrimSuffix(reverseName, IP6arpa) + f = reverse6 + default: + return "" + } + + // Reverse the segments and then combine them. + return f(strings.Split(search, ".")) +} + +// IsReverse returns 0 is name is not in a reverse zone. Anything > 0 indicates +// name is in a reverse zone. The returned integer will be 1 for in-addr.arpa. (IPv4) +// and 2 for ip6.arpa. (IPv6). +func IsReverse(name string) int { + if strings.HasSuffix(name, IP4arpa) { + return 1 + } + if strings.HasSuffix(name, IP6arpa) { + return 2 + } + return 0 +} + +func reverse(slice []string) string { + for i := 0; i < len(slice)/2; i++ { + j := len(slice) - i - 1 + slice[i], slice[j] = slice[j], slice[i] + } + ip := net.ParseIP(strings.Join(slice, ".")).To4() + if ip == nil { + return "" + } + return ip.String() +} + +// reverse6 reverse the segments and combine them according to RFC3596: +// b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2 +// is reversed to 2001:db8::567:89ab +func reverse6(slice []string) string { + for i := 0; i < len(slice)/2; i++ { + j := len(slice) - i - 1 + slice[i], slice[j] = slice[j], slice[i] + } + slice6 := []string{} + for i := 0; i < len(slice)/4; i++ { + slice6 = append(slice6, strings.Join(slice[i*4:i*4+4], "")) + } + ip := net.ParseIP(strings.Join(slice6, ":")).To16() + if ip == nil { + return "" + } + return ip.String() +} + +const ( + // IP4arpa is the reverse tree suffix for v4 IP addresses. + IP4arpa = ".in-addr.arpa." 
+ // IP6arpa is the reverse tree suffix for v6 IP addresses. + IP6arpa = ".ip6.arpa." +) diff --git a/plugin/pkg/dnsutil/reverse_test.go b/plugin/pkg/dnsutil/reverse_test.go new file mode 100644 index 0000000..6fb8279 --- /dev/null +++ b/plugin/pkg/dnsutil/reverse_test.go @@ -0,0 +1,70 @@ +package dnsutil + +import ( + "testing" +) + +func TestExtractAddressFromReverse(t *testing.T) { + tests := []struct { + reverseName string + expectedAddress string + }{ + { + "54.119.58.176.in-addr.arpa.", + "176.58.119.54", + }, + { + ".58.176.in-addr.arpa.", + "", + }, + { + "b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.in-addr.arpa.", + "", + }, + { + "b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", + "2001:db8::567:89ab", + }, + { + "d.0.1.0.0.2.ip6.arpa.", + "", + }, + { + "54.119.58.176.ip6.arpa.", + "", + }, + { + "NONAME", + "", + }, + { + "", + "", + }, + } + for i, test := range tests { + got := ExtractAddressFromReverse(test.reverseName) + if got != test.expectedAddress { + t.Errorf("Test %d, expected '%s', got '%s'", i, test.expectedAddress, got) + } + } +} + +func TestIsReverse(t *testing.T) { + tests := []struct { + name string + expected int + }{ + {"b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.", 2}, + {"d.0.1.0.0.2.in-addr.arpa.", 1}, + {"example.com.", 0}, + {"", 0}, + {"in-addr.arpa.example.com.", 0}, + } + for i, tc := range tests { + got := IsReverse(tc.name) + if got != tc.expected { + t.Errorf("Test %d, got %d, expected %d for %s", i, got, tc.expected, tc.name) + } + } +} diff --git a/plugin/pkg/dnsutil/ttl.go b/plugin/pkg/dnsutil/ttl.go new file mode 100644 index 0000000..c7f423a --- /dev/null +++ b/plugin/pkg/dnsutil/ttl.go @@ -0,0 +1,53 @@ +package dnsutil + +import ( + "time" + + "github.com/coredns/coredns/plugin/pkg/response" + + "github.com/miekg/dns" +) + +// MinimalTTL scans the message returns the lowest TTL found taking into the response.Type of the message. +func MinimalTTL(m *dns.Msg, mt response.Type) time.Duration { + if mt != response.NoError && mt != response.NameError && mt != response.NoData { + return MinimalDefaultTTL + } + + // No records or OPT is the only record, return a short ttl as a fail safe. + if len(m.Answer)+len(m.Ns) == 0 && + (len(m.Extra) == 0 || (len(m.Extra) == 1 && m.Extra[0].Header().Rrtype == dns.TypeOPT)) { + return MinimalDefaultTTL + } + + minTTL := MaximumDefaulTTL + for _, r := range m.Answer { + if r.Header().Ttl < uint32(minTTL.Seconds()) { + minTTL = time.Duration(r.Header().Ttl) * time.Second + } + } + for _, r := range m.Ns { + if r.Header().Ttl < uint32(minTTL.Seconds()) { + minTTL = time.Duration(r.Header().Ttl) * time.Second + } + } + + for _, r := range m.Extra { + if r.Header().Rrtype == dns.TypeOPT { + // OPT records use TTL field for extended rcode and flags + continue + } + if r.Header().Ttl < uint32(minTTL.Seconds()) { + minTTL = time.Duration(r.Header().Ttl) * time.Second + } + } + return minTTL +} + +const ( + // MinimalDefaultTTL is the absolute lowest TTL we use in CoreDNS. + MinimalDefaultTTL = 5 * time.Second + // MaximumDefaulTTL is the maximum TTL was use on RRsets in CoreDNS. 
+ // TODO: rename as MaximumDefaultTTL + MaximumDefaulTTL = 1 * time.Hour +) diff --git a/plugin/pkg/dnsutil/ttl_test.go b/plugin/pkg/dnsutil/ttl_test.go new file mode 100644 index 0000000..0f49bf5 --- /dev/null +++ b/plugin/pkg/dnsutil/ttl_test.go @@ -0,0 +1,72 @@ +package dnsutil + +import ( + "testing" + "time" + + "github.com/coredns/coredns/plugin/pkg/response" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +// See https://github.com/kubernetes/dns/issues/121, add some specific tests for those use cases. + +func TestMinimalTTL(t *testing.T) { + m := new(dns.Msg) + m.SetQuestion("z.alm.im.", dns.TypeA) + m.Ns = []dns.RR{ + test.SOA("alm.im. 1800 IN SOA ivan.ns.cloudflare.com. dns.cloudflare.com. 2025042470 10000 2400 604800 3600"), + } + + utc := time.Now().UTC() + + mt, _ := response.Typify(m, utc) + if mt != response.NoData { + t.Fatalf("Expected type to be response.NoData, got %s", mt) + } + dur := MinimalTTL(m, mt) // minTTL on msg is 3600 (neg. ttl on SOA) + if dur != 1800*time.Second { + t.Fatalf("Expected minttl duration to be %d, got %d", 1800, dur) + } + + m.Rcode = dns.RcodeNameError + mt, _ = response.Typify(m, utc) + if mt != response.NameError { + t.Fatalf("Expected type to be response.NameError, got %s", mt) + } + dur = MinimalTTL(m, mt) // minTTL on msg is 3600 (neg. ttl on SOA) + if dur != 1800*time.Second { + t.Fatalf("Expected minttl duration to be %d, got %d", 1800, dur) + } +} + +func BenchmarkMinimalTTL(b *testing.B) { + m := new(dns.Msg) + m.SetQuestion("example.org.", dns.TypeA) + m.Ns = []dns.RR{ + test.A("a.example.org. 1800 IN A 127.0.0.53"), + test.A("b.example.org. 1900 IN A 127.0.0.53"), + test.A("c.example.org. 1600 IN A 127.0.0.53"), + test.A("d.example.org. 1100 IN A 127.0.0.53"), + test.A("e.example.org. 1000 IN A 127.0.0.53"), + } + m.Extra = []dns.RR{ + test.A("a.example.org. 1800 IN A 127.0.0.53"), + test.A("b.example.org. 1600 IN A 127.0.0.53"), + test.A("c.example.org. 1400 IN A 127.0.0.53"), + test.A("d.example.org. 1200 IN A 127.0.0.53"), + test.A("e.example.org. 1100 IN A 127.0.0.53"), + } + + utc := time.Now().UTC() + mt, _ := response.Typify(m, utc) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + dur := MinimalTTL(m, mt) + if dur != 1000*time.Second { + b.Fatalf("Wrong MinimalTTL %d, expected %d", dur, 1000*time.Second) + } + } +} diff --git a/plugin/pkg/dnsutil/zone.go b/plugin/pkg/dnsutil/zone.go new file mode 100644 index 0000000..579fef1 --- /dev/null +++ b/plugin/pkg/dnsutil/zone.go @@ -0,0 +1,20 @@ +package dnsutil + +import ( + "errors" + + "github.com/miekg/dns" +) + +// TrimZone removes the zone component from q. It returns the trimmed +// name or an error is zone is longer then qname. The trimmed name will be returned +// without a trailing dot. 
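+// For example (as exercised by the tests below), trimming zone "example.org."
+// from "a.b.example.org." yields "a.b", while trimming a zone from itself
+// returns an error.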
+func TrimZone(q string, z string) (string, error) { + zl := dns.CountLabel(z) + i, ok := dns.PrevLabel(q, zl) + if ok || i-1 < 0 { + return "", errors.New("trimzone: overshot qname: " + q + "for zone " + z) + } + // This includes the '.', remove on return + return q[:i-1], nil +} diff --git a/plugin/pkg/dnsutil/zone_test.go b/plugin/pkg/dnsutil/zone_test.go new file mode 100644 index 0000000..81cd1ad --- /dev/null +++ b/plugin/pkg/dnsutil/zone_test.go @@ -0,0 +1,39 @@ +package dnsutil + +import ( + "errors" + "testing" + + "github.com/miekg/dns" +) + +func TestTrimZone(t *testing.T) { + tests := []struct { + qname string + zone string + expected string + err error + }{ + {"a.example.org", "example.org", "a", nil}, + {"a.b.example.org", "example.org", "a.b", nil}, + {"b.", ".", "b", nil}, + {"example.org", "example.org", "", errors.New("should err")}, + {"org", "example.org", "", errors.New("should err")}, + } + + for i, tc := range tests { + got, err := TrimZone(dns.Fqdn(tc.qname), dns.Fqdn(tc.zone)) + if tc.err != nil && err == nil { + t.Errorf("Test %d, expected error got nil", i) + continue + } + if tc.err == nil && err != nil { + t.Errorf("Test %d, expected no error got %v", i, err) + continue + } + if got != tc.expected { + t.Errorf("Test %d, expected %s, got %s", i, tc.expected, got) + continue + } + } +} diff --git a/plugin/pkg/doh/doh.go b/plugin/pkg/doh/doh.go new file mode 100644 index 0000000..faddfc8 --- /dev/null +++ b/plugin/pkg/doh/doh.go @@ -0,0 +1,133 @@ +package doh + +import ( + "bytes" + "encoding/base64" + "fmt" + "io" + "net/http" + "strings" + + "github.com/miekg/dns" +) + +// MimeType is the DoH mimetype that should be used. +const MimeType = "application/dns-message" + +// Path is the URL path that should be used. +const Path = "/dns-query" + +// NewRequest returns a new DoH request given a HTTP method, URL and dns.Msg. +// +// The URL should not have a path, so please exclude /dns-query. The URL will +// be prefixed with https:// by default, unless it's already prefixed with +// either http:// or https://. +func NewRequest(method, url string, m *dns.Msg) (*http.Request, error) { + buf, err := m.Pack() + if err != nil { + return nil, err + } + + if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") { + url = fmt.Sprintf("https://%s", url) + } + + switch method { + case http.MethodGet: + b64 := base64.RawURLEncoding.EncodeToString(buf) + + req, err := http.NewRequest( + http.MethodGet, + fmt.Sprintf("%s%s?dns=%s", url, Path, b64), + nil, + ) + if err != nil { + return req, err + } + + req.Header.Set("content-type", MimeType) + req.Header.Set("accept", MimeType) + return req, nil + + case http.MethodPost: + req, err := http.NewRequest( + http.MethodPost, + fmt.Sprintf("%s%s?bla=foo:443", url, Path), + bytes.NewReader(buf), + ) + if err != nil { + return req, err + } + + req.Header.Set("content-type", MimeType) + req.Header.Set("accept", MimeType) + return req, nil + + default: + return nil, fmt.Errorf("method not allowed: %s", method) + } +} + +// ResponseToMsg converts a http.Response to a dns message. +func ResponseToMsg(resp *http.Response) (*dns.Msg, error) { + defer resp.Body.Close() + + return toMsg(resp.Body) +} + +// RequestToMsg converts a http.Request to a dns message. 
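+// GET requests carry the packed message base64url-encoded in the "dns" query
+// parameter; POST requests carry it verbatim in the request body. Any other
+// method results in an error.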
+func RequestToMsg(req *http.Request) (*dns.Msg, error) { + switch req.Method { + case http.MethodGet: + return requestToMsgGet(req) + + case http.MethodPost: + return requestToMsgPost(req) + + default: + return nil, fmt.Errorf("method not allowed: %s", req.Method) + } +} + +// requestToMsgPost extracts the dns message from the request body. +func requestToMsgPost(req *http.Request) (*dns.Msg, error) { + defer req.Body.Close() + return toMsg(req.Body) +} + +// requestToMsgGet extract the dns message from the GET request. +func requestToMsgGet(req *http.Request) (*dns.Msg, error) { + values := req.URL.Query() + b64, ok := values["dns"] + if !ok { + return nil, fmt.Errorf("no 'dns' query parameter found") + } + if len(b64) != 1 { + return nil, fmt.Errorf("multiple 'dns' query values found") + } + return base64ToMsg(b64[0]) +} + +func toMsg(r io.ReadCloser) (*dns.Msg, error) { + buf, err := io.ReadAll(http.MaxBytesReader(nil, r, 65536)) + if err != nil { + return nil, err + } + m := new(dns.Msg) + err = m.Unpack(buf) + return m, err +} + +func base64ToMsg(b64 string) (*dns.Msg, error) { + buf, err := b64Enc.DecodeString(b64) + if err != nil { + return nil, err + } + + m := new(dns.Msg) + err = m.Unpack(buf) + + return m, err +} + +var b64Enc = base64.RawURLEncoding diff --git a/plugin/pkg/doh/doh_test.go b/plugin/pkg/doh/doh_test.go new file mode 100644 index 0000000..047d013 --- /dev/null +++ b/plugin/pkg/doh/doh_test.go @@ -0,0 +1,46 @@ +package doh + +import ( + "net/http" + "testing" + + "github.com/miekg/dns" +) + +func TestDoH(t *testing.T) { + tests := map[string]struct { + method string + url string + }{ + "POST request over HTTPS": {method: http.MethodPost, url: "https://example.org:443"}, + "POST request over HTTP": {method: http.MethodPost, url: "http://example.org:443"}, + "POST request without protocol": {method: http.MethodPost, url: "example.org:443"}, + "GET request over HTTPS": {method: http.MethodGet, url: "https://example.org:443"}, + "GET request over HTTP": {method: http.MethodGet, url: "http://example.org"}, + "GET request without protocol": {method: http.MethodGet, url: "example.org:443"}, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + m := new(dns.Msg) + m.SetQuestion("example.org.", dns.TypeDNSKEY) + + req, err := NewRequest(test.method, test.url, m) + if err != nil { + t.Errorf("Failure to make request: %s", err) + } + + m, err = RequestToMsg(req) + if err != nil { + t.Fatalf("Failure to get message from request: %s", err) + } + + if x := m.Question[0].Name; x != "example.org." { + t.Errorf("Qname expected %s, got %s", "example.org.", x) + } + if x := m.Question[0].Qtype; x != dns.TypeDNSKEY { + t.Errorf("Qname expected %d, got %d", x, dns.TypeDNSKEY) + } + }) + } +} diff --git a/plugin/pkg/durations/durations.go b/plugin/pkg/durations/durations.go new file mode 100644 index 0000000..37771e7 --- /dev/null +++ b/plugin/pkg/durations/durations.go @@ -0,0 +1,26 @@ +package durations + +import ( + "fmt" + "strconv" + "time" +) + +// NewDurationFromArg returns a time.Duration from a configuration argument +// (string) which has come from the Corefile. The argument has some basic +// validation applied before returning a time.Duration. If the argument has no +// time unit specified and is numeric the argument will be treated as seconds +// rather than GO's default of nanoseconds. 
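+// For example (mirroring the tests below), "30s" and plain "30" both yield
+// 30 * time.Second, whereas "twenty seconds" is rejected with an error.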
+func NewDurationFromArg(arg string) (time.Duration, error) { + _, err := strconv.Atoi(arg) + if err == nil { + arg = arg + "s" + } + + d, err := time.ParseDuration(arg) + if err != nil { + return 0, fmt.Errorf("failed to parse duration '%s'", arg) + } + + return d, nil +} diff --git a/plugin/pkg/durations/durations_test.go b/plugin/pkg/durations/durations_test.go new file mode 100644 index 0000000..12008a7 --- /dev/null +++ b/plugin/pkg/durations/durations_test.go @@ -0,0 +1,51 @@ +package durations + +import ( + "testing" + "time" +) + +func TestNewDurationFromArg(t *testing.T) { + tests := []struct { + name string + arg string + wantErr bool + want time.Duration + }{ + { + name: "valid GO duration - seconds", + arg: "30s", + want: 30 * time.Second, + }, + { + name: "valid GO duration - minutes", + arg: "2m", + want: 2 * time.Minute, + }, + { + name: "number - fallback to seconds", + arg: "30", + want: 30 * time.Second, + }, + { + name: "invalid duration", + arg: "twenty seconds", + wantErr: true, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + actual, err := NewDurationFromArg(test.arg) + if test.wantErr && err == nil { + t.Error("error was expected") + } + if !test.wantErr && err != nil { + t.Error("error was not expected") + } + + if test.want != actual { + t.Errorf("expected '%v' got '%v'", test.want, actual) + } + }) + } +} diff --git a/plugin/pkg/edns/edns.go b/plugin/pkg/edns/edns.go new file mode 100644 index 0000000..cd86399 --- /dev/null +++ b/plugin/pkg/edns/edns.go @@ -0,0 +1,71 @@ +// Package edns provides function useful for adding/inspecting OPT records to/in messages. +package edns + +import ( + "errors" + "sync" + + "github.com/miekg/dns" +) + +var sup = &supported{m: make(map[uint16]struct{})} + +type supported struct { + m map[uint16]struct{} + sync.RWMutex +} + +// SetSupportedOption adds a new supported option the set of EDNS0 options that we support. Plugins typically call +// this in their setup code to signal support for a new option. +// By default we support: +// dns.EDNS0NSID, dns.EDNS0EXPIRE, dns.EDNS0COOKIE, dns.EDNS0TCPKEEPALIVE, dns.EDNS0PADDING. These +// values are not in this map and checked directly in the server. +func SetSupportedOption(option uint16) { + sup.Lock() + sup.m[option] = struct{}{} + sup.Unlock() +} + +// SupportedOption returns true if the option code is supported as an extra EDNS0 option. +func SupportedOption(option uint16) bool { + sup.RLock() + _, ok := sup.m[option] + sup.RUnlock() + return ok +} + +// Version checks the EDNS version in the request. If error +// is nil everything is OK and we can invoke the plugin. If non-nil, the +// returned Msg is valid to be returned to the client (and should). +func Version(req *dns.Msg) (*dns.Msg, error) { + opt := req.IsEdns0() + if opt == nil { + return nil, nil + } + if opt.Version() == 0 { + return nil, nil + } + m := new(dns.Msg) + m.SetReply(req) + + o := new(dns.OPT) + o.Hdr.Name = "." + o.Hdr.Rrtype = dns.TypeOPT + o.SetVersion(0) + m.Rcode = dns.RcodeBadVers + o.SetExtendedRcode(dns.RcodeBadVers) + m.Extra = []dns.RR{o} + + return m, errors.New("EDNS0 BADVERS") +} + +// Size returns a normalized size based on proto. 
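+// For TCP the DNS maximum message size is always returned; for other
+// transports the advertised size is used, but never less than dns.MinMsgSize.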
+func Size(proto string, size uint16) uint16 { + if proto == "tcp" { + return dns.MaxMsgSize + } + if size < dns.MinMsgSize { + return dns.MinMsgSize + } + return size +} diff --git a/plugin/pkg/edns/edns_test.go b/plugin/pkg/edns/edns_test.go new file mode 100644 index 0000000..1976779 --- /dev/null +++ b/plugin/pkg/edns/edns_test.go @@ -0,0 +1,49 @@ +package edns + +import ( + "testing" + + "github.com/miekg/dns" +) + +func TestVersion(t *testing.T) { + m := ednsMsg() + m.Extra[0].(*dns.OPT).SetVersion(2) + + r, err := Version(m) + if err == nil { + t.Errorf("Expected wrong version, but got OK") + } + if r.Question == nil { + t.Errorf("Expected question section, but got nil") + } + if r.Rcode != dns.RcodeBadVers { + t.Errorf("Expected Rcode to be of BADVER (16), but got %d", r.Rcode) + } + if r.Extra == nil { + t.Errorf("Expected OPT section, but got nil") + } +} + +func TestVersionNoEdns(t *testing.T) { + m := ednsMsg() + m.Extra = nil + + r, err := Version(m) + if err != nil { + t.Errorf("Expected no error, but got one: %s", err) + } + if r != nil { + t.Errorf("Expected nil since not an EDNS0 request, but did not got nil") + } +} + +func ednsMsg() *dns.Msg { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + o := new(dns.OPT) + o.Hdr.Name = "." + o.Hdr.Rrtype = dns.TypeOPT + m.Extra = append(m.Extra, o) + return m +} diff --git a/plugin/pkg/expression/expression.go b/plugin/pkg/expression/expression.go new file mode 100644 index 0000000..dad38fe --- /dev/null +++ b/plugin/pkg/expression/expression.go @@ -0,0 +1,47 @@ +package expression + +import ( + "context" + "errors" + "net" + + "github.com/coredns/coredns/plugin/metadata" + "github.com/coredns/coredns/request" +) + +// DefaultEnv returns the default set of custom state variables and functions available to for use in expression evaluation. 
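+// Besides request state such as name, type, class, proto, client_ip and
+// server_ip, the returned map exposes incidr(ip, cidr), which reports whether
+// ip falls within cidr, and metadata(label), which resolves a metadata value
+// from the request context (returning "" when the label is not present).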
+func DefaultEnv(ctx context.Context, state *request.Request) map[string]interface{} { + return map[string]interface{}{ + "incidr": func(ipStr, cidrStr string) (bool, error) { + ip := net.ParseIP(ipStr) + if ip == nil { + return false, errors.New("first argument is not an IP address") + } + _, cidr, err := net.ParseCIDR(cidrStr) + if err != nil { + return false, err + } + return cidr.Contains(ip), nil + }, + "metadata": func(label string) string { + f := metadata.ValueFunc(ctx, label) + if f == nil { + return "" + } + return f() + }, + "type": state.Type, + "name": state.Name, + "class": state.Class, + "proto": state.Proto, + "size": state.Len, + "client_ip": state.IP, + "port": state.Port, + "id": func() int { return int(state.Req.Id) }, + "opcode": func() int { return state.Req.Opcode }, + "do": state.Do, + "bufsize": state.Size, + "server_ip": state.LocalIP, + "server_port": state.LocalPort, + } +} diff --git a/plugin/pkg/expression/expression_test.go b/plugin/pkg/expression/expression_test.go new file mode 100644 index 0000000..b39c679 --- /dev/null +++ b/plugin/pkg/expression/expression_test.go @@ -0,0 +1,73 @@ +package expression + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin/metadata" + "github.com/coredns/coredns/request" +) + +func TestInCidr(t *testing.T) { + incidr := DefaultEnv(context.Background(), &request.Request{})["incidr"] + + cases := []struct { + ip string + cidr string + expected bool + shouldErr bool + }{ + // positive + {ip: "1.2.3.4", cidr: "1.2.0.0/16", expected: true, shouldErr: false}, + {ip: "10.2.3.4", cidr: "1.2.0.0/16", expected: false, shouldErr: false}, + {ip: "1:2::3:4", cidr: "1:2::/64", expected: true, shouldErr: false}, + {ip: "A:2::3:4", cidr: "1:2::/64", expected: false, shouldErr: false}, + // negative + {ip: "1.2.3.4", cidr: "invalid", shouldErr: true}, + {ip: "invalid", cidr: "1.2.0.0/16", shouldErr: true}, + } + + for i, c := range cases { + r, err := incidr.(func(string, string) (bool, error))(c.ip, c.cidr) + if err != nil && !c.shouldErr { + t.Errorf("Test %d: unexpected error %v", i, err) + continue + } + if err == nil && c.shouldErr { + t.Errorf("Test %d: expected error", i) + continue + } + if c.shouldErr { + continue + } + if r != c.expected { + t.Errorf("Test %d: expected %v", i, c.expected) + continue + } + } +} + +func TestMetadata(t *testing.T) { + ctx := metadata.ContextWithMetadata(context.Background()) + metadata.SetValueFunc(ctx, "test/metadata", func() string { + return "success" + }) + f := DefaultEnv(ctx, &request.Request{})["metadata"] + + cases := []struct { + label string + expected string + shouldErr bool + }{ + {label: "test/metadata", expected: "success"}, + {label: "test/nonexistent", expected: ""}, + } + + for i, c := range cases { + r := f.(func(string) string)(c.label) + if r != c.expected { + t.Errorf("Test %d: expected %v", i, c.expected) + continue + } + } +} diff --git a/plugin/pkg/fall/fall.go b/plugin/pkg/fall/fall.go new file mode 100644 index 0000000..898c8db --- /dev/null +++ b/plugin/pkg/fall/fall.go @@ -0,0 +1,70 @@ +// Package fall handles the fallthrough logic used in plugins that support it. Be careful when including this +// functionality in your plugin. Why? In the DNS only 1 source is authoritative for a set of names. Fallthrough +// breaks this convention by allowing a plugin to query multiple sources, depending on the replies it got sofar. +// +// This may cause issues in downstream caches, where different answers for the same query can potentially confuse clients. 
+// On the other hand this is a powerful feature that can aid in migration or other edge cases. +// +// The take away: be mindful of this and don't blindly assume it's a good feature to have in your plugin. +// +// See https://github.com/coredns/coredns/issues/2723 for some discussion on this, which includes this quote: +// +// TL;DR: `fallthrough` is indeed risky and hackish, but still a good feature of CoreDNS as it allows to quickly answer boring edge cases. +package fall + +import ( + "github.com/coredns/coredns/plugin" +) + +// F can be nil to allow for no fallthrough, empty allow all zones to fallthrough or +// contain a zone list that is checked. +type F struct { + Zones []string +} + +// Through will check if we should fallthrough for qname. Note that we've named the +// variable in each plugin "Fall", so this then reads Fall.Through(). +func (f F) Through(qname string) bool { + return plugin.Zones(f.Zones).Matches(qname) != "" +} + +// setZones will set zones in f. +func (f *F) setZones(zones []string) { + z := []string{} + for i := range zones { + z = append(z, plugin.Host(zones[i]).NormalizeExact()...) + } + f.Zones = z +} + +// SetZonesFromArgs sets zones in f to the passed value or to "." if the slice is empty. +func (f *F) SetZonesFromArgs(zones []string) { + if len(zones) == 0 { + f.setZones(Root.Zones) + return + } + f.setZones(zones) +} + +// Equal returns true if f and g are equal. +func (f *F) Equal(g F) bool { + if len(f.Zones) != len(g.Zones) { + return false + } + for i := range f.Zones { + if f.Zones[i] != g.Zones[i] { + return false + } + } + return true +} + +// Zero returns a zero valued F. +var Zero = func() F { + return F{[]string{}} +}() + +// Root returns F set to only ".". +var Root = func() F { + return F{[]string{"."}} +}() diff --git a/plugin/pkg/fall/fall_test.go b/plugin/pkg/fall/fall_test.go new file mode 100644 index 0000000..26cfbc2 --- /dev/null +++ b/plugin/pkg/fall/fall_test.go @@ -0,0 +1,65 @@ +package fall + +import "testing" + +func TestEqual(t *testing.T) { + var z F + f := F{Zones: []string{"example.com."}} + g := F{Zones: []string{"example.net."}} + h := F{Zones: []string{"example.com."}} + + if !f.Equal(h) { + t.Errorf("%v should equal %v", f, h) + } + + if z.Equal(f) { + t.Errorf("%v should not be equal to %v", z, f) + } + + if f.Equal(g) { + t.Errorf("%v should not be equal to %v", f, g) + } +} + +func TestZero(t *testing.T) { + var f F + if !f.Equal(Zero) { + t.Errorf("F should be zero") + } +} + +func TestSetZonesFromArgs(t *testing.T) { + var f F + f.SetZonesFromArgs([]string{}) + if !f.Equal(Root) { + t.Errorf("F should have the root zone") + } + + f.SetZonesFromArgs([]string{"example.com", "example.net."}) + expected := F{Zones: []string{"example.com.", "example.net."}} + if !f.Equal(expected) { + t.Errorf("F should be %v but is %v", expected, f) + } +} + +func TestFallthrough(t *testing.T) { + var fall F + if fall.Through("foo.com.") { + t.Errorf("Expected false, got true for zero fallthrough") + } + + fall.SetZonesFromArgs([]string{}) + if !fall.Through("foo.net.") { + t.Errorf("Expected true, got false for all zone fallthrough") + } + + fall.SetZonesFromArgs([]string{"foo.com", "bar.com"}) + + if fall.Through("foo.net.") { + t.Errorf("Expected false, got true for non-matching fallthrough zone") + } + + if !fall.Through("bar.com.") { + t.Errorf("Expected true, got false for matching fallthrough zone") + } +} diff --git a/plugin/pkg/fuzz/do.go b/plugin/pkg/fuzz/do.go new file mode 100644 index 0000000..054c429 --- /dev/null +++ 
b/plugin/pkg/fuzz/do.go @@ -0,0 +1,31 @@ +// Package fuzz contains functions that enable fuzzing of plugins. +package fuzz + +import ( + "context" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +// Do will fuzz p - used by gofuzz. See Makefile.fuzz for comments and context. +func Do(p plugin.Handler, data []byte) int { + ctx := context.TODO() + r := new(dns.Msg) + if err := r.Unpack(data); err != nil { + return 0 // plugin will never be called when this happens. + } + // If the data unpack into a dns msg, but does not have a proper question section discard it. + // The server parts make sure this is true before calling the plugins; mimic this behavior. + if len(r.Question) == 0 { + return 0 + } + + if _, err := p.ServeDNS(ctx, &test.ResponseWriter{}, r); err != nil { + return 1 + } + + return 0 +} diff --git a/plugin/pkg/log/listener.go b/plugin/pkg/log/listener.go new file mode 100644 index 0000000..2dfe815 --- /dev/null +++ b/plugin/pkg/log/listener.go @@ -0,0 +1,141 @@ +package log + +import ( + "sync" +) + +// Listener listens for all log prints of plugin loggers aka loggers with plugin name. +// When a plugin logger gets called, it should first call the same method in the Listener object. +// A usage example is, the external plugin k8s_event will replicate log prints to Kubernetes events. +type Listener interface { + Name() string + Debug(plugin string, v ...interface{}) + Debugf(plugin string, format string, v ...interface{}) + Info(plugin string, v ...interface{}) + Infof(plugin string, format string, v ...interface{}) + Warning(plugin string, v ...interface{}) + Warningf(plugin string, format string, v ...interface{}) + Error(plugin string, v ...interface{}) + Errorf(plugin string, format string, v ...interface{}) + Fatal(plugin string, v ...interface{}) + Fatalf(plugin string, format string, v ...interface{}) +} + +type listeners struct { + listeners []Listener + sync.RWMutex +} + +var ls *listeners + +func init() { + ls = &listeners{} + ls.listeners = make([]Listener, 0) +} + +// RegisterListener register a listener object. +func RegisterListener(new Listener) error { + ls.Lock() + defer ls.Unlock() + for k, l := range ls.listeners { + if l.Name() == new.Name() { + ls.listeners[k] = new + return nil + } + } + ls.listeners = append(ls.listeners, new) + return nil +} + +// DeregisterListener deregister a listener object. +func DeregisterListener(old Listener) error { + ls.Lock() + defer ls.Unlock() + for k, l := range ls.listeners { + if l.Name() == old.Name() { + ls.listeners = append(ls.listeners[:k], ls.listeners[k+1:]...) + return nil + } + } + return nil +} + +func (ls *listeners) debug(plugin string, v ...interface{}) { + ls.RLock() + for _, l := range ls.listeners { + l.Debug(plugin, v...) + } + ls.RUnlock() +} + +func (ls *listeners) debugf(plugin string, format string, v ...interface{}) { + ls.RLock() + for _, l := range ls.listeners { + l.Debugf(plugin, format, v...) + } + ls.RUnlock() +} + +func (ls *listeners) info(plugin string, v ...interface{}) { + ls.RLock() + for _, l := range ls.listeners { + l.Info(plugin, v...) + } + ls.RUnlock() +} + +func (ls *listeners) infof(plugin string, format string, v ...interface{}) { + ls.RLock() + for _, l := range ls.listeners { + l.Infof(plugin, format, v...) + } + ls.RUnlock() +} + +func (ls *listeners) warning(plugin string, v ...interface{}) { + ls.RLock() + for _, l := range ls.listeners { + l.Warning(plugin, v...) 
+ } + ls.RUnlock() +} + +func (ls *listeners) warningf(plugin string, format string, v ...interface{}) { + ls.RLock() + for _, l := range ls.listeners { + l.Warningf(plugin, format, v...) + } + ls.RUnlock() +} + +func (ls *listeners) error(plugin string, v ...interface{}) { + ls.RLock() + for _, l := range ls.listeners { + l.Error(plugin, v...) + } + ls.RUnlock() +} + +func (ls *listeners) errorf(plugin string, format string, v ...interface{}) { + ls.RLock() + for _, l := range ls.listeners { + l.Errorf(plugin, format, v...) + } + ls.RUnlock() +} + +func (ls *listeners) fatal(plugin string, v ...interface{}) { + ls.RLock() + for _, l := range ls.listeners { + l.Fatal(plugin, v...) + } + ls.RUnlock() +} + +func (ls *listeners) fatalf(plugin string, format string, v ...interface{}) { + ls.RLock() + for _, l := range ls.listeners { + l.Fatalf(plugin, format, v...) + } + ls.RUnlock() +} diff --git a/plugin/pkg/log/listener_test.go b/plugin/pkg/log/listener_test.go new file mode 100644 index 0000000..452d2b3 --- /dev/null +++ b/plugin/pkg/log/listener_test.go @@ -0,0 +1,120 @@ +package log + +import ( + "bytes" + golog "log" + "strings" + "testing" +) + +func TestRegisterAndDeregisterListener(t *testing.T) { + for _, name := range []string{"listener1", "listener2", "listener1"} { + err := RegisterListener(NewMockListener(name)) + if err != nil { + t.Errorf("RegisterListener Error %s", err) + } + } + if len(ls.listeners) != 2 { + t.Errorf("Expected number of listeners to be %d, got %d", 2, len(ls.listeners)) + } + for _, name := range []string{"listener1", "listener2"} { + err := DeregisterListener(NewMockListener(name)) + if err != nil { + t.Errorf("DeregsiterListener Error %s", err) + } + } + if len(ls.listeners) != 0 { + t.Errorf("Expected number of listeners to be %d, got %d", 0, len(ls.listeners)) + } +} + +func TestSingleListenerMock(t *testing.T) { + listener1Name := "listener1" + listener1Output := info + listener1Name + " mocked info" + testListenersCalled(t, []string{listener1Name}, []string{listener1Output}) +} + +func TestMultipleListenerMock(t *testing.T) { + listener1Name := "listener1" + listener1Output := info + listener1Name + " mocked info" + listener2Name := "listener2" + listener2Output := info + listener2Name + " mocked info" + testListenersCalled(t, []string{listener1Name, listener2Name}, []string{listener1Output, listener2Output}) +} + +func testListenersCalled(t *testing.T, listenerNames []string, outputs []string) { + for _, name := range listenerNames { + err := RegisterListener(NewMockListener(name)) + if err != nil { + t.Errorf("RegisterListener Error %s", err) + } + } + var f bytes.Buffer + const ts = "test" + golog.SetOutput(&f) + lg := NewWithPlugin("testplugin") + lg.Info(ts) + for _, str := range outputs { + if x := f.String(); !strings.Contains(x, str) { + t.Errorf("Expected log to contain %s, got %s", str, x) + } + } + for _, name := range listenerNames { + err := DeregisterListener(NewMockListener(name)) + if err != nil { + t.Errorf("DeregsiterListener Error %s", err) + } + } +} + +type mockListener struct { + name string +} + +func NewMockListener(name string) *mockListener { + return &mockListener{name: name} +} + +func (l *mockListener) Name() string { + return l.name +} + +func (l *mockListener) Debug(plugin string, v ...interface{}) { + log(debug, l.name+" mocked debug") +} + +func (l *mockListener) Debugf(plugin string, format string, v ...interface{}) { + log(debug, l.name+" mocked debug") +} + +func (l *mockListener) Info(plugin string, v 
...interface{}) { + log(info, l.name+" mocked info") +} + +func (l *mockListener) Infof(plugin string, format string, v ...interface{}) { + log(info, l.name+" mocked info") +} + +func (l *mockListener) Warning(plugin string, v ...interface{}) { + log(warning, l.name+" mocked warning") +} + +func (l *mockListener) Warningf(plugin string, format string, v ...interface{}) { + log(warning, l.name+" mocked warning") +} + +func (l *mockListener) Error(plugin string, v ...interface{}) { + log(err, l.name+" mocked error") +} + +func (l *mockListener) Errorf(plugin string, format string, v ...interface{}) { + log(err, l.name+" mocked error") +} + +func (l *mockListener) Fatal(plugin string, v ...interface{}) { + log(fatal, l.name+" mocked fatal") +} + +func (l *mockListener) Fatalf(plugin string, format string, v ...interface{}) { + log(fatal, l.name+" mocked fatal") +} diff --git a/plugin/pkg/log/log.go b/plugin/pkg/log/log.go new file mode 100644 index 0000000..0589a34 --- /dev/null +++ b/plugin/pkg/log/log.go @@ -0,0 +1,113 @@ +// Package log implements a small wrapper around the std lib log package. It +// implements log levels by prefixing the logs with [INFO], [DEBUG], [WARNING] +// or [ERROR]. Debug logging is available and enabled if the *debug* plugin is +// used. +// +// log.Info("this is some logging"), will log on the Info level. +// +// log.Debug("this is debug output"), will log in the Debug level, etc. +package log + +import ( + "fmt" + "io" + golog "log" + "os" + "sync" +) + +// D controls whether we should output debug logs. If true, we do, once set +// it can not be unset. +var D = &d{} + +type d struct { + on bool + sync.RWMutex +} + +// Set enables debug logging. +func (d *d) Set() { + d.Lock() + d.on = true + d.Unlock() +} + +// Clear disables debug logging. +func (d *d) Clear() { + d.Lock() + d.on = false + d.Unlock() +} + +// Value returns if debug logging is enabled. +func (d *d) Value() bool { + d.RLock() + b := d.on + d.RUnlock() + return b +} + +// logf calls log.Printf prefixed with level. +func logf(level, format string, v ...interface{}) { + golog.Print(level, fmt.Sprintf(format, v...)) +} + +// log calls log.Print prefixed with level. +func log(level string, v ...interface{}) { + golog.Print(level, fmt.Sprint(v...)) +} + +// Debug is equivalent to log.Print(), but prefixed with "[DEBUG] ". It only outputs something +// if D is true. +func Debug(v ...interface{}) { + if !D.Value() { + return + } + log(debug, v...) +} + +// Debugf is equivalent to log.Printf(), but prefixed with "[DEBUG] ". It only outputs something +// if D is true. +func Debugf(format string, v ...interface{}) { + if !D.Value() { + return + } + logf(debug, format, v...) +} + +// Info is equivalent to log.Print, but prefixed with "[INFO] ". +func Info(v ...interface{}) { log(info, v...) } + +// Infof is equivalent to log.Printf, but prefixed with "[INFO] ". +func Infof(format string, v ...interface{}) { logf(info, format, v...) } + +// Warning is equivalent to log.Print, but prefixed with "[WARNING] ". +func Warning(v ...interface{}) { log(warning, v...) } + +// Warningf is equivalent to log.Printf, but prefixed with "[WARNING] ". +func Warningf(format string, v ...interface{}) { logf(warning, format, v...) } + +// Error is equivalent to log.Print, but prefixed with "[ERROR] ". +func Error(v ...interface{}) { log(err, v...) } + +// Errorf is equivalent to log.Printf, but prefixed with "[ERROR] ". +func Errorf(format string, v ...interface{}) { logf(err, format, v...) 
} + +// Fatal is equivalent to log.Print, but prefixed with "[FATAL] ", and calling +// os.Exit(1). +func Fatal(v ...interface{}) { log(fatal, v...); os.Exit(1) } + +// Fatalf is equivalent to log.Printf, but prefixed with "[FATAL] ", and calling +// os.Exit(1) +func Fatalf(format string, v ...interface{}) { logf(fatal, format, v...); os.Exit(1) } + +// Discard sets the log output to /dev/null. +func Discard() { golog.SetOutput(io.Discard) } + +const ( + debug = "[DEBUG] " + err = "[ERROR] " + fatal = "[FATAL] " + info = "[INFO] " + warning = "[WARNING] " +) diff --git a/plugin/pkg/log/log_test.go b/plugin/pkg/log/log_test.go new file mode 100644 index 0000000..32c1d39 --- /dev/null +++ b/plugin/pkg/log/log_test.go @@ -0,0 +1,72 @@ +package log + +import ( + "bytes" + golog "log" + "strings" + "testing" +) + +func TestDebug(t *testing.T) { + var f bytes.Buffer + golog.SetOutput(&f) + + // D == false + Debug("debug") + if x := f.String(); x != "" { + t.Errorf("Expected no debug logs, got %s", x) + } + f.Reset() + + D.Set() + Debug("debug") + if x := f.String(); !strings.Contains(x, debug+"debug") { + t.Errorf("Expected debug log to be %s, got %s", debug+"debug", x) + } + f.Reset() + + D.Clear() + Debug("debug") + if x := f.String(); x != "" { + t.Errorf("Expected no debug logs, got %s", x) + } +} + +func TestDebugx(t *testing.T) { + var f bytes.Buffer + golog.SetOutput(&f) + + D.Set() + + Debugf("%s", "debug") + if x := f.String(); !strings.Contains(x, debug+"debug") { + t.Errorf("Expected debug log to be %s, got %s", debug+"debug", x) + } + f.Reset() + + Debug("debug") + if x := f.String(); !strings.Contains(x, debug+"debug") { + t.Errorf("Expected debug log to be %s, got %s", debug+"debug", x) + } +} + +func TestLevels(t *testing.T) { + var f bytes.Buffer + const ts = "test" + golog.SetOutput(&f) + + Info(ts) + if x := f.String(); !strings.Contains(x, info+ts) { + t.Errorf("Expected log to be %s, got %s", info+ts, x) + } + f.Reset() + Warning(ts) + if x := f.String(); !strings.Contains(x, warning+ts) { + t.Errorf("Expected log to be %s, got %s", warning+ts, x) + } + f.Reset() + Error(ts) + if x := f.String(); !strings.Contains(x, err+ts) { + t.Errorf("Expected log to be %s, got %s", err+ts, x) + } +} diff --git a/plugin/pkg/log/plugin.go b/plugin/pkg/log/plugin.go new file mode 100644 index 0000000..1be79f1 --- /dev/null +++ b/plugin/pkg/log/plugin.go @@ -0,0 +1,91 @@ +package log + +import ( + "fmt" + "os" +) + +// P is a logger that includes the plugin doing the logging. +type P struct { + plugin string +} + +// NewWithPlugin returns a logger that includes "plugin/name: " in the log message. +// I.e [INFO] plugin/<name>: message. +func NewWithPlugin(name string) P { return P{"plugin/" + name + ": "} } + +func (p P) logf(level, format string, v ...interface{}) { + log(level, p.plugin, fmt.Sprintf(format, v...)) +} + +func (p P) log(level string, v ...interface{}) { + log(level+p.plugin, v...) +} + +// Debug logs as log.Debug. +func (p P) Debug(v ...interface{}) { + if !D.Value() { + return + } + ls.debug(p.plugin, v...) + p.log(debug, v...) +} + +// Debugf logs as log.Debugf. +func (p P) Debugf(format string, v ...interface{}) { + if !D.Value() { + return + } + ls.debugf(p.plugin, format, v...) + p.logf(debug, format, v...) +} + +// Info logs as log.Info. +func (p P) Info(v ...interface{}) { + ls.info(p.plugin, v...) + p.log(info, v...) +} + +// Infof logs as log.Infof. +func (p P) Infof(format string, v ...interface{}) { + ls.infof(p.plugin, format, v...) + p.logf(info, format, v...) 
+} + +// Warning logs as log.Warning. +func (p P) Warning(v ...interface{}) { + ls.warning(p.plugin, v...) + p.log(warning, v...) +} + +// Warningf logs as log.Warningf. +func (p P) Warningf(format string, v ...interface{}) { + ls.warningf(p.plugin, format, v...) + p.logf(warning, format, v...) +} + +// Error logs as log.Error. +func (p P) Error(v ...interface{}) { + ls.error(p.plugin, v...) + p.log(err, v...) +} + +// Errorf logs as log.Errorf. +func (p P) Errorf(format string, v ...interface{}) { + ls.errorf(p.plugin, format, v...) + p.logf(err, format, v...) +} + +// Fatal logs as log.Fatal and calls os.Exit(1). +func (p P) Fatal(v ...interface{}) { + ls.fatal(p.plugin, v...) + p.log(fatal, v...) + os.Exit(1) +} + +// Fatalf logs as log.Fatalf and calls os.Exit(1). +func (p P) Fatalf(format string, v ...interface{}) { + ls.fatalf(p.plugin, format, v...) + p.logf(fatal, format, v...) + os.Exit(1) +} diff --git a/plugin/pkg/log/plugin_test.go b/plugin/pkg/log/plugin_test.go new file mode 100644 index 0000000..b24caa4 --- /dev/null +++ b/plugin/pkg/log/plugin_test.go @@ -0,0 +1,21 @@ +package log + +import ( + "bytes" + golog "log" + "strings" + "testing" +) + +func TestPlugins(t *testing.T) { + var f bytes.Buffer + const ts = "test" + golog.SetOutput(&f) + + lg := NewWithPlugin("testplugin") + + lg.Info(ts) + if x := f.String(); !strings.Contains(x, "plugin/testplugin") { + t.Errorf("Expected log to be %s, got %s", info+ts, x) + } +} diff --git a/plugin/pkg/nonwriter/nonwriter.go b/plugin/pkg/nonwriter/nonwriter.go new file mode 100644 index 0000000..411e98a --- /dev/null +++ b/plugin/pkg/nonwriter/nonwriter.go @@ -0,0 +1,21 @@ +// Package nonwriter implements a dns.ResponseWriter that never writes, but captures the dns.Msg being written. +package nonwriter + +import ( + "github.com/miekg/dns" +) + +// Writer is a type of ResponseWriter that captures the message, but never writes to the client. +type Writer struct { + dns.ResponseWriter + Msg *dns.Msg +} + +// New makes and returns a new NonWriter. +func New(w dns.ResponseWriter) *Writer { return &Writer{ResponseWriter: w} } + +// WriteMsg records the message, but doesn't write it itself. +func (w *Writer) WriteMsg(res *dns.Msg) error { + w.Msg = res + return nil +} diff --git a/plugin/pkg/nonwriter/nonwriter_test.go b/plugin/pkg/nonwriter/nonwriter_test.go new file mode 100644 index 0000000..d8433af --- /dev/null +++ b/plugin/pkg/nonwriter/nonwriter_test.go @@ -0,0 +1,19 @@ +package nonwriter + +import ( + "testing" + + "github.com/miekg/dns" +) + +func TestNonWriter(t *testing.T) { + nw := New(nil) + m := new(dns.Msg) + m.SetQuestion("example.org.", dns.TypeA) + if err := nw.WriteMsg(m); err != nil { + t.Errorf("Got error when writing to nonwriter: %s", err) + } + if x := nw.Msg.Question[0].Name; x != "example.org." { + t.Errorf("Expacted 'example.org.' got %q:", x) + } +} diff --git a/plugin/pkg/parse/host.go b/plugin/pkg/parse/host.go new file mode 100644 index 0000000..f90e4fc --- /dev/null +++ b/plugin/pkg/parse/host.go @@ -0,0 +1,125 @@ +package parse + +import ( + "errors" + "fmt" + "net" + "os" + "strings" + + "github.com/coredns/coredns/plugin/pkg/transport" + + "github.com/miekg/dns" +) + +// ErrNoNameservers is returned by HostPortOrFile if no servers can be parsed. 
+var ErrNoNameservers = errors.New("no nameservers found") + +// Strips the zone, but preserves any port that comes after the zone +func stripZone(host string) string { + if strings.Contains(host, "%") { + lastPercent := strings.LastIndex(host, "%") + newHost := host[:lastPercent] + return newHost + } + return host +} + +// HostPortOrFile parses the strings in s, each string can either be a +// address, [scheme://]address:port or a filename. The address part is checked +// and in case of filename a resolv.conf like file is (assumed) and parsed and +// the nameservers found are returned. +func HostPortOrFile(s ...string) ([]string, error) { + var servers []string + for _, h := range s { + trans, host := Transport(h) + if len(host) == 0 { + return servers, fmt.Errorf("invalid address: %q", h) + } + + if trans == transport.UNIX { + servers = append(servers, trans+"://"+host) + continue + } + + addr, _, err := net.SplitHostPort(host) + + if err != nil { + // Parse didn't work, it is not a addr:port combo + hostNoZone := stripZone(host) + if net.ParseIP(hostNoZone) == nil { + ss, err := tryFile(host) + if err == nil { + servers = append(servers, ss...) + continue + } + return servers, fmt.Errorf("not an IP address or file: %q", host) + } + var ss string + switch trans { + case transport.DNS: + ss = net.JoinHostPort(host, transport.Port) + case transport.TLS: + ss = transport.TLS + "://" + net.JoinHostPort(host, transport.TLSPort) + case transport.QUIC: + ss = transport.QUIC + "://" + net.JoinHostPort(host, transport.QUICPort) + case transport.GRPC: + ss = transport.GRPC + "://" + net.JoinHostPort(host, transport.GRPCPort) + case transport.HTTPS: + ss = transport.HTTPS + "://" + net.JoinHostPort(host, transport.HTTPSPort) + } + servers = append(servers, ss) + continue + } + + if net.ParseIP(stripZone(addr)) == nil { + ss, err := tryFile(host) + if err == nil { + servers = append(servers, ss...) + continue + } + return servers, fmt.Errorf("not an IP address or file: %q", host) + } + servers = append(servers, h) + } + if len(servers) == 0 { + return servers, ErrNoNameservers + } + return servers, nil +} + +// Try to open this is a file first. +func tryFile(s string) ([]string, error) { + c, err := dns.ClientConfigFromFile(s) + if err == os.ErrNotExist { + return nil, fmt.Errorf("failed to open file %q: %q", s, err) + } else if err != nil { + return nil, err + } + + servers := []string{} + for _, s := range c.Servers { + servers = append(servers, net.JoinHostPort(s, c.Port)) + } + return servers, nil +} + +// HostPort will check if the host part is a valid IP address, if the +// IP address is valid, but no port is found, defaultPort is added. 
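+// For example, HostPort("8.8.8.8", "53") returns "8.8.8.8:53", and an IPv6
+// address such as "fe80::1" comes back bracketed as "[fe80::1]:53"; see
+// host_test.go for more cases.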
+func HostPort(s, defaultPort string) (string, error) { + addr, port, err := net.SplitHostPort(s) + if port == "" { + port = defaultPort + } + if err != nil { + if net.ParseIP(s) == nil { + return "", fmt.Errorf("must specify an IP address: `%s'", s) + } + return net.JoinHostPort(s, port), nil + } + + if net.ParseIP(addr) == nil { + return "", fmt.Errorf("must specify an IP address: `%s'", addr) + } + return net.JoinHostPort(addr, port), nil +} diff --git a/plugin/pkg/parse/host_test.go b/plugin/pkg/parse/host_test.go new file mode 100644 index 0000000..0b5f6f1 --- /dev/null +++ b/plugin/pkg/parse/host_test.go @@ -0,0 +1,121 @@ +package parse + +import ( + "os" + "testing" + + "github.com/coredns/coredns/plugin/pkg/transport" +) + +func TestHostPortOrFile(t *testing.T) { + tests := []struct { + in string + expected string + shouldErr bool + }{ + { + "8.8.8.8", + "8.8.8.8:53", + false, + }, + { + "8.8.8.8:153", + "8.8.8.8:153", + false, + }, + { + "/etc/resolv.conf:53", + "", + true, + }, + { + "resolv.conf", + "127.0.0.1:53", + false, + }, + { + "fe80::1", + "[fe80::1]:53", + false, + }, + { + "fe80::1%ens3", + "[fe80::1%ens3]:53", + false, + }, + { + "[fd01::1]:153", + "[fd01::1]:153", + false, + }, + { + "[fd01::1%ens3]:153", + "[fd01::1%ens3]:153", + false, + }, + { + "8.9.1043", + "", + true, + }, + { + "unix:///var/run/g.sock", + "unix:///var/run/g.sock", + false, + }, + { + "unix://", + "", + true, + }, + } + + err := os.WriteFile("resolv.conf", []byte("nameserver 127.0.0.1\n"), 0600) + if err != nil { + t.Fatalf("Failed to write test resolv.conf") + } + defer os.Remove("resolv.conf") + + for i, tc := range tests { + got, err := HostPortOrFile(tc.in) + if err == nil && tc.shouldErr { + t.Errorf("Test %d, expected error, got nil", i) + continue + } + if err != nil && tc.shouldErr { + continue + } + if got[0] != tc.expected { + t.Errorf("Test %d, expected %q, got %q", i, tc.expected, got[0]) + } + } +} + +func TestParseHostPort(t *testing.T) { + tests := []struct { + in string + expected string + shouldErr bool + }{ + {"8.8.8.8:53", "8.8.8.8:53", false}, + {"a.a.a.a:153", "", true}, + {"8.8.8.8", "8.8.8.8:53", false}, + {"8.8.8.8:", "8.8.8.8:53", false}, + {"8.8.8.8::53", "", true}, + {"resolv.conf", "", true}, + } + + for i, tc := range tests { + got, err := HostPort(tc.in, transport.Port) + if err == nil && tc.shouldErr { + t.Errorf("Test %d, expected error, got nil", i) + continue + } + if err != nil && !tc.shouldErr { + t.Errorf("Test %d, expected no error, got %q", i, err) + } + if got != tc.expected { + t.Errorf("Test %d, expected %q, got %q", i, tc.expected, got) + } + } +} diff --git a/plugin/pkg/parse/parse.go b/plugin/pkg/parse/parse.go new file mode 100644 index 0000000..300a57a --- /dev/null +++ b/plugin/pkg/parse/parse.go @@ -0,0 +1,38 @@ +// Package parse contains functions that can be used in the setup code for plugins. +package parse + +import ( + "fmt" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/plugin/pkg/transport" +) + +// TransferIn parses transfer statements: 'transfer from [address...]'. 
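+// For example, the Corefile directive "transfer from 127.0.0.1" yields
+// []string{"127.0.0.1:53"}, while a bare "from" or "from *" is rejected with
+// an error (see parse_test.go).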
+func TransferIn(c *caddy.Controller) (froms []string, err error) { + if !c.NextArg() { + return nil, c.ArgErr() + } + value := c.Val() + switch value { + default: + return nil, c.Errf("unknown property %s", value) + case "from": + froms = c.RemainingArgs() + if len(froms) == 0 { + return nil, c.ArgErr() + } + for i := range froms { + if froms[i] != "*" { + normalized, err := HostPort(froms[i], transport.Port) + if err != nil { + return nil, err + } + froms[i] = normalized + } else { + return nil, fmt.Errorf("can't use '*' in transfer from") + } + } + } + return froms, nil +} diff --git a/plugin/pkg/parse/parse_test.go b/plugin/pkg/parse/parse_test.go new file mode 100644 index 0000000..4f253a9 --- /dev/null +++ b/plugin/pkg/parse/parse_test.go @@ -0,0 +1,59 @@ +package parse + +import ( + "testing" + + "github.com/coredns/caddy" +) + +func TestTransferIn(t *testing.T) { + tests := []struct { + inputFileRules string + shouldErr bool + expectedFrom []string + }{ + { + `from 127.0.0.1`, + false, []string{"127.0.0.1:53"}, + }, + // OK transfer froms + { + `from 127.0.0.1 127.0.0.2`, + false, []string{"127.0.0.1:53", "127.0.0.2:53"}, + }, + // Bad transfer from garbage + { + `from !@#$%^&*()`, + true, []string{}, + }, + // Bad transfer from no args + { + `from`, + true, []string{}, + }, + // Bad transfer from * + { + `from *`, + true, []string{}, + }, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.inputFileRules) + froms, err := TransferIn(c) + + if err == nil && test.shouldErr { + t.Fatalf("Test %d expected errors, but got no error %+v %+v", i, err, test) + } else if err != nil && !test.shouldErr { + t.Fatalf("Test %d expected no errors, but got '%v'", i, err) + } + + if test.expectedFrom != nil { + for j, got := range froms { + if got != test.expectedFrom[j] { + t.Fatalf("Test %d expected %v, got %v", i, test.expectedFrom[j], got) + } + } + } + } +} diff --git a/plugin/pkg/parse/transport.go b/plugin/pkg/parse/transport.go new file mode 100644 index 0000000..f0cf1c2 --- /dev/null +++ b/plugin/pkg/parse/transport.go @@ -0,0 +1,40 @@ +package parse + +import ( + "strings" + + "github.com/coredns/coredns/plugin/pkg/transport" +) + +// Transport returns the transport defined in s and a string where the +// transport prefix is removed (if there was any). 
If no transport is defined +// we default to TransportDNS +func Transport(s string) (trans string, addr string) { + switch { + case strings.HasPrefix(s, transport.TLS+"://"): + s = s[len(transport.TLS+"://"):] + return transport.TLS, s + + case strings.HasPrefix(s, transport.DNS+"://"): + s = s[len(transport.DNS+"://"):] + return transport.DNS, s + + case strings.HasPrefix(s, transport.QUIC+"://"): + s = s[len(transport.QUIC+"://"):] + return transport.QUIC, s + + case strings.HasPrefix(s, transport.GRPC+"://"): + s = s[len(transport.GRPC+"://"):] + return transport.GRPC, s + + case strings.HasPrefix(s, transport.HTTPS+"://"): + s = s[len(transport.HTTPS+"://"):] + + return transport.HTTPS, s + case strings.HasPrefix(s, transport.UNIX+"://"): + s = s[len(transport.UNIX+"://"):] + return transport.UNIX, s + } + + return transport.DNS, s +} diff --git a/plugin/pkg/parse/transport_test.go b/plugin/pkg/parse/transport_test.go new file mode 100644 index 0000000..d0e0fcd --- /dev/null +++ b/plugin/pkg/parse/transport_test.go @@ -0,0 +1,25 @@ +package parse + +import ( + "testing" + + "github.com/coredns/coredns/plugin/pkg/transport" +) + +func TestTransport(t *testing.T) { + for i, test := range []struct { + input string + expected string + }{ + {"dns://.:53", transport.DNS}, + {"2003::1/64.:53", transport.DNS}, + {"grpc://example.org:1443 ", transport.GRPC}, + {"tls://example.org ", transport.TLS}, + {"https://example.org ", transport.HTTPS}, + } { + actual, _ := Transport(test.input) + if actual != test.expected { + t.Errorf("Test %d: Expected %s but got %s", i, test.expected, actual) + } + } +} diff --git a/plugin/pkg/proxy/connect.go b/plugin/pkg/proxy/connect.go new file mode 100644 index 0000000..27385a4 --- /dev/null +++ b/plugin/pkg/proxy/connect.go @@ -0,0 +1,188 @@ +// Package proxy implements a forwarding proxy. It caches an upstream net.Conn for some time, so if the same +// client returns the upstream's Conn will be precached. Depending on how you benchmark this looks to be +// 50% faster than just opening a new connection for every client. It works with UDP and TCP and uses +// inband healthchecking. +package proxy + +import ( + "context" + "errors" + "io" + "strconv" + "strings" + "sync/atomic" + "time" + + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// limitTimeout is a utility function to auto-tune timeout values +// average observed time is moved towards the last observed delay moderated by a weight +// next timeout to use will be the double of the computed average, limited by min and max frame. +func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration { + rt := time.Duration(atomic.LoadInt64(currentAvg)) + if rt < minValue { + return minValue + } + if rt < maxValue/2 { + return 2 * rt + } + return maxValue +} + +func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) { + dt := time.Duration(atomic.LoadInt64(currentAvg)) + atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight) +} + +func (t *Transport) dialTimeout() time.Duration { + return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout) +} + +func (t *Transport) updateDialTimeout(newDialTime time.Duration) { + averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight) +} + +// Dial dials the address configured in transport, potentially reusing a connection or creating a new one. +func (t *Transport) Dial(proto string) (*persistConn, bool, error) { + // If tls has been configured; use it. 
+ if t.tlsConfig != nil { + proto = "tcp-tls" + } + + t.dial <- proto + pc := <-t.ret + + if pc != nil { + connCacheHitsCount.WithLabelValues(t.proxyName, t.addr, proto).Add(1) + return pc, true, nil + } + connCacheMissesCount.WithLabelValues(t.proxyName, t.addr, proto).Add(1) + + reqTime := time.Now() + timeout := t.dialTimeout() + if proto == "tcp-tls" { + conn, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, timeout) + t.updateDialTimeout(time.Since(reqTime)) + return &persistConn{c: conn}, false, err + } + conn, err := dns.DialTimeout(proto, t.addr, timeout) + t.updateDialTimeout(time.Since(reqTime)) + return &persistConn{c: conn}, false, err +} + +// Connect selects an upstream, sends the request and waits for a response. +func (p *Proxy) Connect(ctx context.Context, state request.Request, opts Options) (*dns.Msg, error) { + start := time.Now() + + proto := "" + switch { + case opts.ForceTCP: // TCP flag has precedence over UDP flag + proto = "tcp" + case opts.PreferUDP: + proto = "udp" + default: + proto = state.Proto() + } + + pc, cached, err := p.transport.Dial(proto) + if err != nil { + return nil, err + } + + // Set buffer size correctly for this client. + pc.c.UDPSize = uint16(state.Size()) + if pc.c.UDPSize < 512 { + pc.c.UDPSize = 512 + } + + pc.c.SetWriteDeadline(time.Now().Add(maxTimeout)) + // records the origin Id before upstream. + originId := state.Req.Id + state.Req.Id = dns.Id() + defer func() { + state.Req.Id = originId + }() + + if err := pc.c.WriteMsg(state.Req); err != nil { + pc.c.Close() // not giving it back + if err == io.EOF && cached { + return nil, ErrCachedClosed + } + return nil, err + } + + var ret *dns.Msg + pc.c.SetReadDeadline(time.Now().Add(p.readTimeout)) + for { + ret, err = pc.c.ReadMsg() + if err != nil { + if ret != nil && (state.Req.Id == ret.Id) && p.transport.transportTypeFromConn(pc) == typeUDP && shouldTruncateResponse(err) { + // For UDP, if the error is an overflow, we probably have an upstream misbehaving in some way. + // (e.g. sending >512 byte responses without an eDNS0 OPT RR). + // Instead of returning an error, return an empty response with TC bit set. This will make the + // client retry over TCP (if that's supported) or at least receive a clean + // error. The connection is still good so we break before the close. + + // Truncate the response. + ret = truncateResponse(ret) + break + } + + pc.c.Close() // not giving it back + if err == io.EOF && cached { + return nil, ErrCachedClosed + } + // recovery the origin Id after upstream. + if ret != nil { + ret.Id = originId + } + return ret, err + } + // drop out-of-order responses + if state.Req.Id == ret.Id { + break + } + } + // recovery the origin Id after upstream. + ret.Id = originId + + p.transport.Yield(pc) + + rc, ok := dns.RcodeToString[ret.Rcode] + if !ok { + rc = strconv.Itoa(ret.Rcode) + } + + requestDuration.WithLabelValues(p.proxyName, p.addr, rc).Observe(time.Since(start).Seconds()) + + return ret, nil +} + +const cumulativeAvgWeight = 4 + +// Function to determine if a response should be truncated. +func shouldTruncateResponse(err error) bool { + // This is to handle a scenario in which upstream sets the TC bit, but doesn't truncate the response + // and we get ErrBuf instead of overflow. + if _, isDNSErr := err.(*dns.Error); isDNSErr && errors.Is(err, dns.ErrBuf) { + return true + } else if strings.Contains(err.Error(), "overflow") { + return true + } + return false +} + +// Function to return an empty response with TC (truncated) bit set. 
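+// The emptied reply is returned to the client so that it can retry the query
+// over TCP, as described in Connect above.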
+func truncateResponse(response *dns.Msg) *dns.Msg { + // Clear out Answer, Extra, and Ns sections + response.Answer = nil + response.Extra = nil + response.Ns = nil + + // Set TC bit to indicate truncation. + response.Truncated = true + return response +} diff --git a/plugin/pkg/proxy/errors.go b/plugin/pkg/proxy/errors.go new file mode 100644 index 0000000..4612364 --- /dev/null +++ b/plugin/pkg/proxy/errors.go @@ -0,0 +1,26 @@ +package proxy + +import ( + "errors" +) + +var ( + // ErrNoHealthy means no healthy proxies left. + ErrNoHealthy = errors.New("no healthy proxies") + // ErrNoForward means no forwarder defined. + ErrNoForward = errors.New("no forwarder defined") + // ErrCachedClosed means cached connection was closed by peer. + ErrCachedClosed = errors.New("cached connection was closed by peer") +) + +// Options holds various Options that can be set. +type Options struct { + // ForceTCP use TCP protocol for upstream DNS request. Has precedence over PreferUDP flag + ForceTCP bool + // PreferUDP use UDP protocol for upstream DNS request. + PreferUDP bool + // HCRecursionDesired sets recursion desired flag for Proxy healthcheck requests + HCRecursionDesired bool + // HCDomain sets domain for Proxy healthcheck requests + HCDomain string +} diff --git a/plugin/pkg/proxy/health.go b/plugin/pkg/proxy/health.go new file mode 100644 index 0000000..4b4b4cc --- /dev/null +++ b/plugin/pkg/proxy/health.go @@ -0,0 +1,134 @@ +package proxy + +import ( + "crypto/tls" + "sync/atomic" + "time" + + "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/plugin/pkg/transport" + + "github.com/miekg/dns" +) + +// HealthChecker checks the upstream health. +type HealthChecker interface { + Check(*Proxy) error + SetTLSConfig(*tls.Config) + GetTLSConfig() *tls.Config + SetRecursionDesired(bool) + GetRecursionDesired() bool + SetDomain(domain string) + GetDomain() string + SetTCPTransport() + GetReadTimeout() time.Duration + SetReadTimeout(time.Duration) + GetWriteTimeout() time.Duration + SetWriteTimeout(time.Duration) +} + +// dnsHc is a health checker for a DNS endpoint (DNS, and DoT). +type dnsHc struct { + c *dns.Client + recursionDesired bool + domain string + + proxyName string +} + +// NewHealthChecker returns a new HealthChecker based on transport. 
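+// Plain DNS and DoT get a *dnsHc; any other transport logs a warning and
+// returns nil. A typical setup, mirroring health_test.go (the proxy name is
+// only illustrative), looks like:
+//
+//	hc := NewHealthChecker("myproxy", transport.DNS, true, ".")
+//	hc.SetReadTimeout(time.Second)
+//	err := hc.Check(p) // p is the *Proxy being probed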
+func NewHealthChecker(proxyName, trans string, recursionDesired bool, domain string) HealthChecker { + switch trans { + case transport.DNS, transport.TLS: + c := new(dns.Client) + c.Net = "udp" + c.ReadTimeout = 1 * time.Second + c.WriteTimeout = 1 * time.Second + + return &dnsHc{ + c: c, + recursionDesired: recursionDesired, + domain: domain, + proxyName: proxyName, + } + } + + log.Warningf("No healthchecker for transport %q", trans) + return nil +} + +func (h *dnsHc) SetTLSConfig(cfg *tls.Config) { + h.c.Net = "tcp-tls" + h.c.TLSConfig = cfg +} + +func (h *dnsHc) GetTLSConfig() *tls.Config { + return h.c.TLSConfig +} + +func (h *dnsHc) SetRecursionDesired(recursionDesired bool) { + h.recursionDesired = recursionDesired +} +func (h *dnsHc) GetRecursionDesired() bool { + return h.recursionDesired +} + +func (h *dnsHc) SetDomain(domain string) { + h.domain = domain +} +func (h *dnsHc) GetDomain() string { + return h.domain +} + +func (h *dnsHc) SetTCPTransport() { + h.c.Net = "tcp" +} + +func (h *dnsHc) GetReadTimeout() time.Duration { + return h.c.ReadTimeout +} + +func (h *dnsHc) SetReadTimeout(t time.Duration) { + h.c.ReadTimeout = t +} + +func (h *dnsHc) GetWriteTimeout() time.Duration { + return h.c.WriteTimeout +} + +func (h *dnsHc) SetWriteTimeout(t time.Duration) { + h.c.WriteTimeout = t +} + +// For HC, we send to . IN NS +[no]rec message to the upstream. Dial timeouts and empty +// replies are considered fails, basically anything else constitutes a healthy upstream. + +// Check is used as the up.Func in the up.Probe. +func (h *dnsHc) Check(p *Proxy) error { + err := h.send(p.addr) + if err != nil { + healthcheckFailureCount.WithLabelValues(p.proxyName, p.addr).Add(1) + p.incrementFails() + return err + } + + atomic.StoreUint32(&p.fails, 0) + return nil +} + +func (h *dnsHc) send(addr string) error { + ping := new(dns.Msg) + ping.SetQuestion(h.domain, dns.TypeNS) + ping.MsgHdr.RecursionDesired = h.recursionDesired + + m, _, err := h.c.Exchange(ping, addr) + // If we got a header, we're alright, basically only care about I/O errors 'n stuff. + if err != nil && m != nil { + // Silly check, something sane came back. + if m.Response || m.Opcode == dns.OpcodeQuery { + err = nil + } + } + + return err +} diff --git a/plugin/pkg/proxy/health_test.go b/plugin/pkg/proxy/health_test.go new file mode 100644 index 0000000..bb93d77 --- /dev/null +++ b/plugin/pkg/proxy/health_test.go @@ -0,0 +1,153 @@ +package proxy + +import ( + "sync/atomic" + "testing" + "time" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/pkg/transport" + + "github.com/miekg/dns" +) + +func TestHealth(t *testing.T) { + i := uint32(0) + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + if r.Question[0].Name == "." 
&& r.RecursionDesired == true { + atomic.AddUint32(&i, 1) + } + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + hc := NewHealthChecker("TestHealth", transport.DNS, true, ".") + hc.SetReadTimeout(10 * time.Millisecond) + hc.SetWriteTimeout(10 * time.Millisecond) + + p := NewProxy("TestHealth", s.Addr, transport.DNS) + p.readTimeout = 10 * time.Millisecond + err := hc.Check(p) + if err != nil { + t.Errorf("check failed: %v", err) + } + + time.Sleep(20 * time.Millisecond) + i1 := atomic.LoadUint32(&i) + if i1 != 1 { + t.Errorf("Expected number of health checks with RecursionDesired==true to be %d, got %d", 1, i1) + } +} + +func TestHealthTCP(t *testing.T) { + i := uint32(0) + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + if r.Question[0].Name == "." && r.RecursionDesired == true { + atomic.AddUint32(&i, 1) + } + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + hc := NewHealthChecker("TestHealthTCP", transport.DNS, true, ".") + hc.SetTCPTransport() + hc.SetReadTimeout(10 * time.Millisecond) + hc.SetWriteTimeout(10 * time.Millisecond) + + p := NewProxy("TestHealthTCP", s.Addr, transport.DNS) + p.readTimeout = 10 * time.Millisecond + err := hc.Check(p) + if err != nil { + t.Errorf("check failed: %v", err) + } + + time.Sleep(20 * time.Millisecond) + i1 := atomic.LoadUint32(&i) + if i1 != 1 { + t.Errorf("Expected number of health checks with RecursionDesired==true to be %d, got %d", 1, i1) + } +} + +func TestHealthNoRecursion(t *testing.T) { + i := uint32(0) + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + if r.Question[0].Name == "." && r.RecursionDesired == false { + atomic.AddUint32(&i, 1) + } + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + hc := NewHealthChecker("TestHealthNoRecursion", transport.DNS, false, ".") + hc.SetReadTimeout(10 * time.Millisecond) + hc.SetWriteTimeout(10 * time.Millisecond) + + p := NewProxy("TestHealthNoRecursion", s.Addr, transport.DNS) + p.readTimeout = 10 * time.Millisecond + err := hc.Check(p) + if err != nil { + t.Errorf("check failed: %v", err) + } + + time.Sleep(20 * time.Millisecond) + i1 := atomic.LoadUint32(&i) + if i1 != 1 { + t.Errorf("Expected number of health checks with RecursionDesired==false to be %d, got %d", 1, i1) + } +} + +func TestHealthTimeout(t *testing.T) { + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + // timeout + }) + defer s.Close() + + hc := NewHealthChecker("TestHealthTimeout", transport.DNS, false, ".") + hc.SetReadTimeout(10 * time.Millisecond) + hc.SetWriteTimeout(10 * time.Millisecond) + + p := NewProxy("TestHealthTimeout", s.Addr, transport.DNS) + p.readTimeout = 10 * time.Millisecond + err := hc.Check(p) + if err == nil { + t.Errorf("expected error") + } +} + +func TestHealthDomain(t *testing.T) { + hcDomain := "example.org." 
+ + i := uint32(0) + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + if r.Question[0].Name == hcDomain && r.RecursionDesired == true { + atomic.AddUint32(&i, 1) + } + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + hc := NewHealthChecker("TestHealthDomain", transport.DNS, true, hcDomain) + hc.SetReadTimeout(10 * time.Millisecond) + hc.SetWriteTimeout(10 * time.Millisecond) + + p := NewProxy("TestHealthDomain", s.Addr, transport.DNS) + p.readTimeout = 10 * time.Millisecond + err := hc.Check(p) + if err != nil { + t.Errorf("check failed: %v", err) + } + + time.Sleep(12 * time.Millisecond) + i1 := atomic.LoadUint32(&i) + if i1 != 1 { + t.Errorf("Expected number of health checks with Domain==%s to be %d, got %d", hcDomain, 1, i1) + } +} diff --git a/plugin/pkg/proxy/metrics.go b/plugin/pkg/proxy/metrics.go new file mode 100644 index 0000000..e4cae97 --- /dev/null +++ b/plugin/pkg/proxy/metrics.go @@ -0,0 +1,40 @@ +package proxy + +import ( + "github.com/coredns/coredns/plugin" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +// Variables declared for monitoring. +var ( + requestDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: plugin.Namespace, + Subsystem: "proxy", + Name: "request_duration_seconds", + Buckets: plugin.TimeBuckets, + Help: "Histogram of the time each request took.", + }, []string{"proxy_name", "to", "rcode"}) + + healthcheckFailureCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "proxy", + Name: "healthcheck_failures_total", + Help: "Counter of the number of failed healthchecks.", + }, []string{"proxy_name", "to"}) + + connCacheHitsCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "proxy", + Name: "conn_cache_hits_total", + Help: "Counter of connection cache hits per upstream and protocol.", + }, []string{"proxy_name", "to", "proto"}) + + connCacheMissesCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "proxy", + Name: "conn_cache_misses_total", + Help: "Counter of connection cache misses per upstream and protocol.", + }, []string{"proxy_name", "to", "proto"}) +) diff --git a/plugin/pkg/proxy/persistent.go b/plugin/pkg/proxy/persistent.go new file mode 100644 index 0000000..49c9dd3 --- /dev/null +++ b/plugin/pkg/proxy/persistent.go @@ -0,0 +1,158 @@ +package proxy + +import ( + "crypto/tls" + "sort" + "time" + + "github.com/miekg/dns" +) + +// a persistConn hold the dns.Conn and the last used time. +type persistConn struct { + c *dns.Conn + used time.Time +} + +// Transport hold the persistent cache. +type Transport struct { + avgDialTime int64 // kind of average time of dial time + conns [typeTotalCount][]*persistConn // Buckets for udp, tcp and tcp-tls. + expire time.Duration // After this duration a connection is expired. 
+ addr string + tlsConfig *tls.Config + proxyName string + + dial chan string + yield chan *persistConn + ret chan *persistConn + stop chan bool +} + +func newTransport(proxyName, addr string) *Transport { + t := &Transport{ + avgDialTime: int64(maxDialTimeout / 2), + conns: [typeTotalCount][]*persistConn{}, + expire: defaultExpire, + addr: addr, + dial: make(chan string), + yield: make(chan *persistConn), + ret: make(chan *persistConn), + stop: make(chan bool), + proxyName: proxyName, + } + return t +} + +// connManager manages the persistent connection cache for UDP and TCP. +func (t *Transport) connManager() { + ticker := time.NewTicker(defaultExpire) + defer ticker.Stop() +Wait: + for { + select { + case proto := <-t.dial: + transtype := stringToTransportType(proto) + // take the last used conn - complexity O(1) + if stack := t.conns[transtype]; len(stack) > 0 { + pc := stack[len(stack)-1] + if time.Since(pc.used) < t.expire { + // Found one, remove from pool and return this conn. + t.conns[transtype] = stack[:len(stack)-1] + t.ret <- pc + continue Wait + } + // clear entire cache if the last conn is expired + t.conns[transtype] = nil + // now, the connections being passed to closeConns() are not reachable from + // transport methods anymore. So, it's safe to close them in a separate goroutine + go closeConns(stack) + } + t.ret <- nil + + case pc := <-t.yield: + transtype := t.transportTypeFromConn(pc) + t.conns[transtype] = append(t.conns[transtype], pc) + + case <-ticker.C: + t.cleanup(false) + + case <-t.stop: + t.cleanup(true) + close(t.ret) + return + } + } +} + +// closeConns closes connections. +func closeConns(conns []*persistConn) { + for _, pc := range conns { + pc.c.Close() + } +} + +// cleanup removes connections from cache. +func (t *Transport) cleanup(all bool) { + staleTime := time.Now().Add(-t.expire) + for transtype, stack := range t.conns { + if len(stack) == 0 { + continue + } + if all { + t.conns[transtype] = nil + // now, the connections being passed to closeConns() are not reachable from + // transport methods anymore. So, it's safe to close them in a separate goroutine + go closeConns(stack) + continue + } + if stack[0].used.After(staleTime) { + continue + } + + // connections in stack are sorted by "used" + good := sort.Search(len(stack), func(i int) bool { + return stack[i].used.After(staleTime) + }) + t.conns[transtype] = stack[good:] + // now, the connections being passed to closeConns() are not reachable from + // transport methods anymore. So, it's safe to close them in a separate goroutine + go closeConns(stack[:good]) + } +} + +// It is hard to pin a value to this, the import thing is to no block forever, losing at cached connection is not terrible. +const yieldTimeout = 25 * time.Millisecond + +// Yield returns the connection to transport for reuse. +func (t *Transport) Yield(pc *persistConn) { + pc.used = time.Now() // update used time + + // Make this non-blocking, because in the case of a very busy forwarder we will *block* on this yield. This + // blocks the outer go-routine and stuff will just pile up. We timeout when the send fails to as returning + // these connection is an optimization anyway. + select { + case t.yield <- pc: + return + case <-time.After(yieldTimeout): + return + } +} + +// Start starts the transport's connection manager. +func (t *Transport) Start() { go t.connManager() } + +// Stop stops the transport's connection manager. 
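+// Stopping is the last step of the usual lifecycle exercised in
+// persistent_test.go: newTransport, Start, then Dial and Yield while serving
+// queries, and finally Stop.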
+func (t *Transport) Stop() { close(t.stop) } + +// SetExpire sets the connection expire time in transport. +func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire } + +// SetTLSConfig sets the TLS config in transport. +func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg } + +const ( + defaultExpire = 10 * time.Second + minDialTimeout = 1 * time.Second + maxDialTimeout = 30 * time.Second +) diff --git a/plugin/pkg/proxy/persistent_test.go b/plugin/pkg/proxy/persistent_test.go new file mode 100644 index 0000000..56d8371 --- /dev/null +++ b/plugin/pkg/proxy/persistent_test.go @@ -0,0 +1,109 @@ +package proxy + +import ( + "testing" + "time" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + + "github.com/miekg/dns" +) + +func TestCached(t *testing.T) { + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + tr := newTransport("TestCached", s.Addr) + tr.Start() + defer tr.Stop() + + c1, cache1, _ := tr.Dial("udp") + c2, cache2, _ := tr.Dial("udp") + + if cache1 || cache2 { + t.Errorf("Expected non-cached connection") + } + + tr.Yield(c1) + tr.Yield(c2) + c3, cached3, _ := tr.Dial("udp") + if !cached3 { + t.Error("Expected cached connection (c3)") + } + if c2 != c3 { + t.Error("Expected c2 == c3") + } + + tr.Yield(c3) + + // dial another protocol + c4, cached4, _ := tr.Dial("tcp") + if cached4 { + t.Errorf("Expected non-cached connection (c4)") + } + tr.Yield(c4) +} + +func TestCleanupByTimer(t *testing.T) { + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + tr := newTransport("TestCleanupByTimer", s.Addr) + tr.SetExpire(100 * time.Millisecond) + tr.Start() + defer tr.Stop() + + c1, _, _ := tr.Dial("udp") + c2, _, _ := tr.Dial("udp") + tr.Yield(c1) + time.Sleep(10 * time.Millisecond) + tr.Yield(c2) + + time.Sleep(120 * time.Millisecond) + c3, cached, _ := tr.Dial("udp") + if cached { + t.Error("Expected non-cached connection (c3)") + } + tr.Yield(c3) + + time.Sleep(120 * time.Millisecond) + c4, cached, _ := tr.Dial("udp") + if cached { + t.Error("Expected non-cached connection (c4)") + } + tr.Yield(c4) +} + +func TestCleanupAll(t *testing.T) { + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + tr := newTransport("TestCleanupAll", s.Addr) + + c1, _ := dns.DialTimeout("udp", tr.addr, maxDialTimeout) + c2, _ := dns.DialTimeout("udp", tr.addr, maxDialTimeout) + c3, _ := dns.DialTimeout("udp", tr.addr, maxDialTimeout) + + tr.conns[typeUDP] = []*persistConn{{c1, time.Now()}, {c2, time.Now()}, {c3, time.Now()}} + + if len(tr.conns[typeUDP]) != 3 { + t.Error("Expected 3 connections") + } + tr.cleanup(true) + + if len(tr.conns[typeUDP]) > 0 { + t.Error("Expected no cached connections") + } +} diff --git a/plugin/pkg/proxy/proxy.go b/plugin/pkg/proxy/proxy.go new file mode 100644 index 0000000..99fb5df --- /dev/null +++ b/plugin/pkg/proxy/proxy.go @@ -0,0 +1,111 @@ +package proxy + +import ( + "crypto/tls" + "runtime" + "sync/atomic" + "time" + + "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/plugin/pkg/up" +) + +// Proxy defines an upstream host. 
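+// A Proxy is created with NewProxy, started with Start (which kicks off
+// health checking and the connection manager), queried with Connect and shut
+// down with Stop; proxy_test.go shows the sequence end to end.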
+type Proxy struct { + fails uint32 + addr string + proxyName string + + transport *Transport + + readTimeout time.Duration + + // health checking + probe *up.Probe + health HealthChecker +} + +// NewProxy returns a new proxy. +func NewProxy(proxyName, addr, trans string) *Proxy { + p := &Proxy{ + addr: addr, + fails: 0, + probe: up.New(), + readTimeout: 2 * time.Second, + transport: newTransport(proxyName, addr), + health: NewHealthChecker(proxyName, trans, true, "."), + proxyName: proxyName, + } + + runtime.SetFinalizer(p, (*Proxy).finalizer) + return p +} + +func (p *Proxy) Addr() string { return p.addr } + +// SetTLSConfig sets the TLS config in the lower p.transport and in the healthchecking client. +func (p *Proxy) SetTLSConfig(cfg *tls.Config) { + p.transport.SetTLSConfig(cfg) + p.health.SetTLSConfig(cfg) +} + +// SetExpire sets the expire duration in the lower p.transport. +func (p *Proxy) SetExpire(expire time.Duration) { p.transport.SetExpire(expire) } + +func (p *Proxy) GetHealthchecker() HealthChecker { + return p.health +} + +func (p *Proxy) Fails() uint32 { + return atomic.LoadUint32(&p.fails) +} + +// Healthcheck kicks of a round of health checks for this proxy. +func (p *Proxy) Healthcheck() { + if p.health == nil { + log.Warning("No healthchecker") + return + } + + p.probe.Do(func() error { + return p.health.Check(p) + }) +} + +// Down returns true if this proxy is down, i.e. has *more* fails than maxfails. +func (p *Proxy) Down(maxfails uint32) bool { + if maxfails == 0 { + return false + } + + fails := atomic.LoadUint32(&p.fails) + return fails > maxfails +} + +// Stop close stops the health checking goroutine. +func (p *Proxy) Stop() { p.probe.Stop() } +func (p *Proxy) finalizer() { p.transport.Stop() } + +// Start starts the proxy's healthchecking. +func (p *Proxy) Start(duration time.Duration) { + p.probe.Start(duration) + p.transport.Start() +} + +func (p *Proxy) SetReadTimeout(duration time.Duration) { + p.readTimeout = duration +} + +// incrementFails increments the number of fails safely. +func (p *Proxy) incrementFails() { + curVal := atomic.LoadUint32(&p.fails) + if curVal > curVal+1 { + // overflow occurred, do not update the counter again + return + } + atomic.AddUint32(&p.fails, 1) +} + +const ( + maxTimeout = 2 * time.Second +) diff --git a/plugin/pkg/proxy/proxy_test.go b/plugin/pkg/proxy/proxy_test.go new file mode 100644 index 0000000..03d10ce --- /dev/null +++ b/plugin/pkg/proxy/proxy_test.go @@ -0,0 +1,225 @@ +package proxy + +import ( + "context" + "crypto/tls" + "errors" + "math" + "testing" + "time" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/pkg/transport" + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +func TestProxy(t *testing.T) { + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + ret.Answer = append(ret.Answer, test.A("example.org. 
IN A 127.0.0.1")) + w.WriteMsg(ret) + }) + defer s.Close() + + p := NewProxy("TestProxy", s.Addr, transport.DNS) + p.readTimeout = 10 * time.Millisecond + p.Start(5 * time.Second) + m := new(dns.Msg) + + m.SetQuestion("example.org.", dns.TypeA) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + req := request.Request{Req: m, W: rec} + + resp, err := p.Connect(context.Background(), req, Options{PreferUDP: true}) + if err != nil { + t.Errorf("Failed to connect to testdnsserver: %s", err) + } + + if x := resp.Answer[0].Header().Name; x != "example.org." { + t.Errorf("Expected %s, got %s", "example.org.", x) + } +} + +func TestProxyTLSFail(t *testing.T) { + // This is an udp/tcp test server, so we shouldn't reach it with TLS. + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + ret.Answer = append(ret.Answer, test.A("example.org. IN A 127.0.0.1")) + w.WriteMsg(ret) + }) + defer s.Close() + + p := NewProxy("TestProxyTLSFail", s.Addr, transport.TLS) + p.readTimeout = 10 * time.Millisecond + p.SetTLSConfig(&tls.Config{}) + p.Start(5 * time.Second) + m := new(dns.Msg) + + m.SetQuestion("example.org.", dns.TypeA) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + req := request.Request{Req: m, W: rec} + + _, err := p.Connect(context.Background(), req, Options{}) + if err == nil { + t.Fatal("Expected *not* to receive reply, but got one") + } +} + +func TestProtocolSelection(t *testing.T) { + p := NewProxy("TestProtocolSelection", "bad_address", transport.DNS) + p.readTimeout = 10 * time.Millisecond + + stateUDP := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)} + stateTCP := request.Request{W: &test.ResponseWriter{TCP: true}, Req: new(dns.Msg)} + ctx := context.TODO() + + go func() { + p.Connect(ctx, stateUDP, Options{}) + p.Connect(ctx, stateUDP, Options{ForceTCP: true}) + p.Connect(ctx, stateUDP, Options{PreferUDP: true}) + p.Connect(ctx, stateUDP, Options{PreferUDP: true, ForceTCP: true}) + p.Connect(ctx, stateTCP, Options{}) + p.Connect(ctx, stateTCP, Options{ForceTCP: true}) + p.Connect(ctx, stateTCP, Options{PreferUDP: true}) + p.Connect(ctx, stateTCP, Options{PreferUDP: true, ForceTCP: true}) + }() + + for i, exp := range []string{"udp", "tcp", "udp", "tcp", "tcp", "tcp", "udp", "tcp"} { + proto := <-p.transport.dial + p.transport.ret <- nil + if proto != exp { + t.Errorf("Unexpected protocol in case %d, expected %q, actual %q", i, exp, proto) + } + } +} + +func TestProxyIncrementFails(t *testing.T) { + var testCases = []struct { + name string + fails uint32 + expectFails uint32 + }{ + { + name: "increment fails counter overflows", + fails: math.MaxUint32, + expectFails: math.MaxUint32, + }, + { + name: "increment fails counter", + fails: 0, + expectFails: 1, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + p := NewProxy("TestProxyIncrementFails", "bad_address", transport.DNS) + p.fails = tc.fails + p.incrementFails() + if p.fails != tc.expectFails { + t.Errorf("Expected fails to be %d, got %d", tc.expectFails, p.fails) + } + }) + } +} + +func TestCoreDNSOverflow(t *testing.T) { + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + + answers := []dns.RR{ + test.A("example.org. IN A 127.0.0.1"), + test.A("example.org. IN A 127.0.0.2"), + test.A("example.org. IN A 127.0.0.3"), + test.A("example.org. IN A 127.0.0.4"), + test.A("example.org. IN A 127.0.0.5"), + test.A("example.org. IN A 127.0.0.6"), + test.A("example.org. 
IN A 127.0.0.7"), + test.A("example.org. IN A 127.0.0.8"), + test.A("example.org. IN A 127.0.0.9"), + test.A("example.org. IN A 127.0.0.10"), + test.A("example.org. IN A 127.0.0.11"), + test.A("example.org. IN A 127.0.0.12"), + test.A("example.org. IN A 127.0.0.13"), + test.A("example.org. IN A 127.0.0.14"), + test.A("example.org. IN A 127.0.0.15"), + test.A("example.org. IN A 127.0.0.16"), + test.A("example.org. IN A 127.0.0.17"), + test.A("example.org. IN A 127.0.0.18"), + test.A("example.org. IN A 127.0.0.19"), + test.A("example.org. IN A 127.0.0.20"), + } + ret.Answer = answers + w.WriteMsg(ret) + }) + defer s.Close() + + p := NewProxy("TestCoreDNSOverflow", s.Addr, transport.DNS) + p.readTimeout = 10 * time.Millisecond + p.Start(5 * time.Second) + defer p.Stop() + + // Test different connection modes + testConnection := func(proto string, options Options, expectTruncated bool) { + t.Helper() + + queryMsg := new(dns.Msg) + queryMsg.SetQuestion("example.org.", dns.TypeA) + + recorder := dnstest.NewRecorder(&test.ResponseWriter{}) + request := request.Request{Req: queryMsg, W: recorder} + + response, err := p.Connect(context.Background(), request, options) + if err != nil { + t.Errorf("Failed to connect to testdnsserver: %s", err) + } + + if response.Truncated != expectTruncated { + t.Errorf("Expected truncated response for %s, but got TC flag %v", proto, response.Truncated) + } + } + + // Test PreferUDP, expect truncated response + testConnection("PreferUDP", Options{PreferUDP: true}, true) + + // Test ForceTCP, expect no truncated response + testConnection("ForceTCP", Options{ForceTCP: true}, false) + + // Test No options specified, expect truncated response + testConnection("NoOptionsSpecified", Options{}, true) + + // Test both TCP and UDP provided, expect no truncated response + testConnection("BothTCPAndUDP", Options{PreferUDP: true, ForceTCP: true}, false) +} + +func TestShouldTruncateResponse(t *testing.T) { + testCases := []struct { + testname string + err error + expected bool + }{ + {"BadAlgorithm", dns.ErrAlg, false}, + {"BufferSizeTooSmall", dns.ErrBuf, true}, + {"OverflowUnpackingA", errors.New("overflow unpacking a"), true}, + {"OverflowingHeaderSize", errors.New("overflowing header size"), true}, + {"OverflowpackingA", errors.New("overflow packing a"), true}, + {"ErrSig", dns.ErrSig, false}, + } + + for _, tc := range testCases { + t.Run(tc.testname, func(t *testing.T) { + result := shouldTruncateResponse(tc.err) + if result != tc.expected { + t.Errorf("For testname '%v', expected %v but got %v", tc.testname, tc.expected, result) + } + }) + } +} diff --git a/plugin/pkg/proxy/type.go b/plugin/pkg/proxy/type.go new file mode 100644 index 0000000..10f3a46 --- /dev/null +++ b/plugin/pkg/proxy/type.go @@ -0,0 +1,39 @@ +package proxy + +import ( + "net" +) + +type transportType int + +const ( + typeUDP transportType = iota + typeTCP + typeTLS + typeTotalCount // keep this last +) + +func stringToTransportType(s string) transportType { + switch s { + case "udp": + return typeUDP + case "tcp": + return typeTCP + case "tcp-tls": + return typeTLS + } + + return typeUDP +} + +func (t *Transport) transportTypeFromConn(pc *persistConn) transportType { + if _, ok := pc.c.Conn.(*net.UDPConn); ok { + return typeUDP + } + + if t.tlsConfig == nil { + return typeTCP + } + + return typeTLS +} diff --git a/plugin/pkg/rand/rand.go b/plugin/pkg/rand/rand.go new file mode 100644 index 0000000..490f59b --- /dev/null +++ b/plugin/pkg/rand/rand.go @@ -0,0 +1,35 @@ +// Package rand is used for 
concurrency safe random number generator. +package rand + +import ( + "math/rand" + "sync" +) + +// Rand is used for concurrency safe random number generator. +type Rand struct { + m sync.Mutex + r *rand.Rand +} + +// New returns a new Rand from seed. +func New(seed int64) *Rand { + return &Rand{r: rand.New(rand.NewSource(seed))} +} + +// Int returns a non-negative pseudo-random int from the Source in Rand.r. +func (r *Rand) Int() int { + r.m.Lock() + v := r.r.Int() + r.m.Unlock() + return v +} + +// Perm returns, as a slice of n ints, a pseudo-random permutation of the +// integers in the half-open interval [0,n) from the Source in Rand.r. +func (r *Rand) Perm(n int) []int { + r.m.Lock() + v := r.r.Perm(n) + r.m.Unlock() + return v +} diff --git a/plugin/pkg/rcode/rcode.go b/plugin/pkg/rcode/rcode.go new file mode 100644 index 0000000..d221bcb --- /dev/null +++ b/plugin/pkg/rcode/rcode.go @@ -0,0 +1,15 @@ +package rcode + +import ( + "strconv" + + "github.com/miekg/dns" +) + +// ToString convert the rcode to the official DNS string, or to "RCODE"+value if the RCODE value is unknown. +func ToString(rcode int) string { + if str, ok := dns.RcodeToString[rcode]; ok { + return str + } + return "RCODE" + strconv.Itoa(rcode) +} diff --git a/plugin/pkg/rcode/rcode_test.go b/plugin/pkg/rcode/rcode_test.go new file mode 100644 index 0000000..bfca32f --- /dev/null +++ b/plugin/pkg/rcode/rcode_test.go @@ -0,0 +1,29 @@ +package rcode + +import ( + "testing" + + "github.com/miekg/dns" +) + +func TestToString(t *testing.T) { + tests := []struct { + in int + expected string + }{ + { + dns.RcodeSuccess, + "NOERROR", + }, + { + 28, + "RCODE28", + }, + } + for i, test := range tests { + got := ToString(test.in) + if got != test.expected { + t.Errorf("Test %d, expected %s, got %s", i, test.expected, got) + } + } +} diff --git a/plugin/pkg/replacer/replacer.go b/plugin/pkg/replacer/replacer.go new file mode 100644 index 0000000..4572443 --- /dev/null +++ b/plugin/pkg/replacer/replacer.go @@ -0,0 +1,284 @@ +package replacer + +import ( + "context" + "strconv" + "strings" + "sync" + "time" + + "github.com/coredns/coredns/plugin/metadata" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// Replacer replaces labels for values in strings. +type Replacer struct{} + +// New makes a new replacer. This only needs to be called once in the setup and +// then call Replace for each incoming message. A replacer is safe for concurrent use. +func New() Replacer { + return Replacer{} +} + +// Replace performs a replacement of values on s and returns the string with the replaced values. +func (r Replacer) Replace(ctx context.Context, state request.Request, rr *dnstest.Recorder, s string) string { + return loadFormat(s).Replace(ctx, state, rr) +} + +const ( + headerReplacer = "{>" + // EmptyValue is the default empty value. + EmptyValue = "-" +) + +// labels are all supported labels that can be used in the default Replacer. +var labels = map[string]struct{}{ + "{type}": {}, + "{name}": {}, + "{class}": {}, + "{proto}": {}, + "{size}": {}, + "{remote}": {}, + "{port}": {}, + "{local}": {}, + // Header values. + headerReplacer + "id}": {}, + headerReplacer + "opcode}": {}, + headerReplacer + "do}": {}, + headerReplacer + "bufsize}": {}, + // Recorded replacements. + "{rcode}": {}, + "{rsize}": {}, + "{duration}": {}, + headerReplacer + "rflags}": {}, +} + +// appendValue appends the current value of label. 
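+// For the HINFO query to example.org. used in replacer_test.go, "{type}"
+// appends "HINFO", "{name}" appends "example.org." and "{size}" appends "29".
+// Unknown labels append EmptyValue, as do the response labels when rr is nil
+// and the request labels when state is empty.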
+func appendValue(b []byte, state request.Request, rr *dnstest.Recorder, label string) []byte { + switch label { + // Recorded replacements. + case "{rcode}": + if rr == nil || rr.Msg == nil { + return append(b, EmptyValue...) + } + if rcode := dns.RcodeToString[rr.Rcode]; rcode != "" { + return append(b, rcode...) + } + return strconv.AppendInt(b, int64(rr.Rcode), 10) + case "{rsize}": + if rr == nil { + return append(b, EmptyValue...) + } + return strconv.AppendInt(b, int64(rr.Len), 10) + case "{duration}": + if rr == nil { + return append(b, EmptyValue...) + } + secs := time.Since(rr.Start).Seconds() + return append(strconv.AppendFloat(b, secs, 'f', -1, 64), 's') + case headerReplacer + "rflags}": + if rr != nil && rr.Msg != nil { + return appendFlags(b, rr.Msg.MsgHdr) + } + return append(b, EmptyValue...) + } + + if (request.Request{}) == state { + return append(b, EmptyValue...) + } + + switch label { + case "{type}": + return append(b, state.Type()...) + case "{name}": + return append(b, state.Name()...) + case "{class}": + return append(b, state.Class()...) + case "{proto}": + return append(b, state.Proto()...) + case "{size}": + return strconv.AppendInt(b, int64(state.Req.Len()), 10) + case "{remote}": + return appendAddrToRFC3986(b, state.IP()) + case "{port}": + return append(b, state.Port()...) + case "{local}": + return appendAddrToRFC3986(b, state.LocalIP()) + // Header placeholders (case-insensitive). + case headerReplacer + "id}": + return strconv.AppendInt(b, int64(state.Req.Id), 10) + case headerReplacer + "opcode}": + return strconv.AppendInt(b, int64(state.Req.Opcode), 10) + case headerReplacer + "do}": + return strconv.AppendBool(b, state.Do()) + case headerReplacer + "bufsize}": + return strconv.AppendInt(b, int64(state.Size()), 10) + default: + return append(b, EmptyValue...) + } +} + +// appendFlags checks all header flags and appends those +// that are set as a string separated with commas +func appendFlags(b []byte, h dns.MsgHdr) []byte { + origLen := len(b) + if h.Response { + b = append(b, "qr,"...) + } + if h.Authoritative { + b = append(b, "aa,"...) + } + if h.Truncated { + b = append(b, "tc,"...) + } + if h.RecursionDesired { + b = append(b, "rd,"...) + } + if h.RecursionAvailable { + b = append(b, "ra,"...) + } + if h.Zero { + b = append(b, "z,"...) + } + if h.AuthenticatedData { + b = append(b, "ad,"...) + } + if h.CheckingDisabled { + b = append(b, "cd,"...) + } + if n := len(b); n > origLen { + return b[:n-1] // trim trailing ',' + } + return b +} + +// appendAddrToRFC3986 will add brackets to the address if it is an IPv6 address. +func appendAddrToRFC3986(b []byte, addr string) []byte { + if strings.IndexByte(addr, ':') != -1 { + b = append(b, '[') + b = append(b, addr...) + b = append(b, ']') + } else { + b = append(b, addr...) + } + return b +} + +type nodeType int + +const ( + typeLabel nodeType = iota // "{type}" + typeLiteral // "foo" + typeMetadata // "{/metadata}" +) + +// A node represents a segment of a parsed format. For example: "A {type}" +// contains two nodes: "A " (literal); and "{type}" (label). +type node struct { + value string // Literal value, label or metadata label + typ nodeType +} + +// A replacer is an ordered list of all the nodes in a format. +type replacer []node + +func parseFormat(s string) replacer { + // Assume there is a literal between each label - its cheaper to over + // allocate once than allocate twice. 
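+	//
+	// For example, "A {type}" parses into the literal node "A " followed by the
+	// label node "{type}", while "A {A}" stays a single literal because "{A}" is
+	// not a recognized label (see TestParseFormat).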
+ rep := make(replacer, 0, strings.Count(s, "{")*2) + for { + // We find the right bracket then backtrack to find the left bracket. + // This allows us to handle formats like: "{ {foo} }". + j := strings.IndexByte(s, '}') + if j < 0 { + break + } + i := strings.LastIndexByte(s[:j], '{') + if i < 0 { + // Handle: "A } {foo}" by treating "A }" as a literal + rep = append(rep, node{ + value: s[:j+1], + typ: typeLiteral, + }) + s = s[j+1:] + continue + } + + val := s[i : j+1] + var typ nodeType + switch _, ok := labels[val]; { + case ok: + typ = typeLabel + case strings.HasPrefix(val, "{/"): + // Strip "{/}" from metadata labels + val = val[2 : len(val)-1] + typ = typeMetadata + default: + // Given: "A {X}" val is "{X}" expand it to the whole literal. + val = s[:j+1] + typ = typeLiteral + } + + // Append any leading literal. Given "A {type}" the literal is "A " + if i != 0 && typ != typeLiteral { + rep = append(rep, node{ + value: s[:i], + typ: typeLiteral, + }) + } + rep = append(rep, node{ + value: val, + typ: typ, + }) + s = s[j+1:] + } + if len(s) != 0 { + rep = append(rep, node{ + value: s, + typ: typeLiteral, + }) + } + return rep +} + +var replacerCache sync.Map // map[string]replacer + +func loadFormat(s string) replacer { + if v, ok := replacerCache.Load(s); ok { + return v.(replacer) + } + v, _ := replacerCache.LoadOrStore(s, parseFormat(s)) + return v.(replacer) +} + +// bufPool stores pointers to scratch buffers. +var bufPool = sync.Pool{ + New: func() interface{} { + return make([]byte, 0, 256) + }, +} + +func (r replacer) Replace(ctx context.Context, state request.Request, rr *dnstest.Recorder) string { + b := bufPool.Get().([]byte) + for _, s := range r { + switch s.typ { + case typeLabel: + b = appendValue(b, state, rr, s.value) + case typeLiteral: + b = append(b, s.value...) + case typeMetadata: + if fm := metadata.ValueFunc(ctx, s.value); fm != nil { + b = append(b, fm()...) + } else { + b = append(b, EmptyValue...) + } + } + } + s := string(b) + //nolint:staticcheck + bufPool.Put(b[:0]) + return s +} diff --git a/plugin/pkg/replacer/replacer_test.go b/plugin/pkg/replacer/replacer_test.go new file mode 100644 index 0000000..aa8ac6f --- /dev/null +++ b/plugin/pkg/replacer/replacer_test.go @@ -0,0 +1,448 @@ +package replacer + +import ( + "context" + "reflect" + "strings" + "testing" + + "github.com/coredns/coredns/plugin/metadata" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// This is the default format used by the log package +const CommonLogFormat = `{remote}:{port} - {>id} "{type} {class} {name} {proto} {size} {>do} {>bufsize}" {rcode} {>rflags} {rsize} {duration}` + +func TestReplacer(t *testing.T) { + w := dnstest.NewRecorder(&test.ResponseWriter{}) + r := new(dns.Msg) + r.SetQuestion("example.org.", dns.TypeHINFO) + r.MsgHdr.AuthenticatedData = true + state := request.Request{W: w, Req: r} + + replacer := New() + + if x := replacer.Replace(context.TODO(), state, nil, "{type}"); x != "HINFO" { + t.Errorf("Expected type to be HINFO, got %q", x) + } + if x := replacer.Replace(context.TODO(), state, nil, "{name}"); x != "example.org." 
{ + t.Errorf("Expected request name to be example.org., got %q", x) + } + if x := replacer.Replace(context.TODO(), state, nil, "{size}"); x != "29" { + t.Errorf("Expected size to be 29, got %q", x) + } +} + +func TestParseFormat(t *testing.T) { + type formatTest struct { + Format string + Expected replacer + } + tests := []formatTest{ + { + Format: "", + Expected: replacer{}, + }, + { + Format: "A", + Expected: replacer{ + {"A", typeLiteral}, + }, + }, + { + Format: "A {A}", + Expected: replacer{ + {"A {A}", typeLiteral}, + }, + }, + { + Format: "{{remote}}", + Expected: replacer{ + {"{", typeLiteral}, + {"{remote}", typeLabel}, + {"}", typeLiteral}, + }, + }, + { + Format: "{ A {remote} A }", + Expected: replacer{ + {"{ A ", typeLiteral}, + {"{remote}", typeLabel}, + {" A }", typeLiteral}, + }, + }, + { + Format: "{remote}}", + Expected: replacer{ + {"{remote}", typeLabel}, + {"}", typeLiteral}, + }, + }, + { + Format: "{{remote}", + Expected: replacer{ + {"{", typeLiteral}, + {"{remote}", typeLabel}, + }, + }, + { + Format: `Foo } {remote}`, + Expected: replacer{ + // we don't do any optimizations to join adjacent literals + {"Foo }", typeLiteral}, + {" ", typeLiteral}, + {"{remote}", typeLabel}, + }, + }, + { + Format: `{ Foo`, + Expected: replacer{ + {"{ Foo", typeLiteral}, + }, + }, + { + Format: `} Foo`, + Expected: replacer{ + {"}", typeLiteral}, + {" Foo", typeLiteral}, + }, + }, + { + Format: "A { {remote} {type} {/meta1} } B", + Expected: replacer{ + {"A { ", typeLiteral}, + {"{remote}", typeLabel}, + {" ", typeLiteral}, + {"{type}", typeLabel}, + {" ", typeLiteral}, + {"meta1", typeMetadata}, + {" }", typeLiteral}, + {" B", typeLiteral}, + }, + }, + { + Format: `LOG {remote}:{port} - {>id} "{type} {class} {name} {proto} ` + + `{size} {>do} {>bufsize}" {rcode} {>rflags} {rsize} {/meta1}-{/meta2} ` + + `{duration} END OF LINE`, + Expected: replacer{ + {"LOG ", typeLiteral}, + {"{remote}", typeLabel}, + {":", typeLiteral}, + {"{port}", typeLabel}, + {" - ", typeLiteral}, + {"{>id}", typeLabel}, + {` "`, typeLiteral}, + {"{type}", typeLabel}, + {" ", typeLiteral}, + {"{class}", typeLabel}, + {" ", typeLiteral}, + {"{name}", typeLabel}, + {" ", typeLiteral}, + {"{proto}", typeLabel}, + {" ", typeLiteral}, + {"{size}", typeLabel}, + {" ", typeLiteral}, + {"{>do}", typeLabel}, + {" ", typeLiteral}, + {"{>bufsize}", typeLabel}, + {`" `, typeLiteral}, + {"{rcode}", typeLabel}, + {" ", typeLiteral}, + {"{>rflags}", typeLabel}, + {" ", typeLiteral}, + {"{rsize}", typeLabel}, + {" ", typeLiteral}, + {"meta1", typeMetadata}, + {"-", typeLiteral}, + {"meta2", typeMetadata}, + {" ", typeLiteral}, + {"{duration}", typeLabel}, + {" END OF LINE", typeLiteral}, + }, + }, + } + for i, x := range tests { + r := parseFormat(x.Format) + if !reflect.DeepEqual(r, x.Expected) { + t.Errorf("%d: Expected:\n\t%+v\nGot:\n\t%+v", i, x.Expected, r) + } + } +} + +func TestParseFormatNodes(t *testing.T) { + // If we parse the format successfully the result of joining all the + // segments should match the original format. 
+ formats := []string{ + "", + "msg", + "{remote}", + "{remote}", + "{{remote}", + "{{remote}}", + "{{remote}} A", + CommonLogFormat, + CommonLogFormat + " FOO} {BAR}", + "A " + CommonLogFormat + " FOO} {BAR}", + "A " + CommonLogFormat + " {/meta}", + } + join := func(r replacer) string { + a := make([]string, len(r)) + for i, n := range r { + if n.typ == typeMetadata { + a[i] = "{/" + n.value + "}" + } else { + a[i] = n.value + } + } + return strings.Join(a, "") + } + for _, format := range formats { + r := parseFormat(format) + s := join(r) + if s != format { + t.Errorf("Expected format to be: '%s' got: '%s'", format, s) + } + } +} + +func TestLabels(t *testing.T) { + w := dnstest.NewRecorder(&test.ResponseWriter{}) + r := new(dns.Msg) + r.SetQuestion("example.org.", dns.TypeHINFO) + r.Id = 1053 + r.AuthenticatedData = true + r.CheckingDisabled = true + w.WriteMsg(r) + state := request.Request{W: w, Req: r} + + replacer := New() + ctx := context.TODO() + + // This couples the test very tightly to the code, but so be it. + expect := map[string]string{ + "{type}": "HINFO", + "{name}": "example.org.", + "{class}": "IN", + "{proto}": "udp", + "{size}": "29", + "{remote}": "10.240.0.1", + "{port}": "40212", + "{local}": "127.0.0.1", + headerReplacer + "id}": "1053", + headerReplacer + "opcode}": "0", + headerReplacer + "do}": "false", + headerReplacer + "bufsize}": "512", + "{rcode}": "NOERROR", + "{rsize}": "29", + "{duration}": "0", + headerReplacer + "rflags}": "rd,ad,cd", + } + if len(expect) != len(labels) { + t.Fatalf("Expect %d labels, got %d", len(expect), len(labels)) + } + + for lbl := range labels { + repl := replacer.Replace(ctx, state, w, lbl) + if lbl == "{duration}" { + if repl[len(repl)-1] != 's' { + t.Errorf("Expected seconds, got %q", repl) + } + continue + } + if repl != expect[lbl] { + t.Errorf("Expected value %q, got %q", expect[lbl], repl) + } + + // test empty state and nil recorder won't panic + repl_empty := replacer.Replace(ctx, request.Request{}, nil, lbl) + if repl_empty != EmptyValue { + t.Errorf("Expected empty value %q, got %q", EmptyValue, repl_empty) + } + } +} + +func BenchmarkReplacer(b *testing.B) { + w := dnstest.NewRecorder(&test.ResponseWriter{}) + r := new(dns.Msg) + r.SetQuestion("example.org.", dns.TypeHINFO) + r.MsgHdr.AuthenticatedData = true + state := request.Request{W: w, Req: r} + + b.ResetTimer() + b.ReportAllocs() + + replacer := New() + for i := 0; i < b.N; i++ { + replacer.Replace(context.TODO(), state, nil, "{type} {name} {size}") + } +} + +func BenchmarkReplacer_CommonLogFormat(b *testing.B) { + w := dnstest.NewRecorder(&test.ResponseWriter{}) + r := new(dns.Msg) + r.SetQuestion("example.org.", dns.TypeHINFO) + r.Id = 1053 + r.AuthenticatedData = true + r.CheckingDisabled = true + r.MsgHdr.AuthenticatedData = true + w.WriteMsg(r) + state := request.Request{W: w, Req: r} + + replacer := New() + ctxt := context.TODO() + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + replacer.Replace(ctxt, state, w, CommonLogFormat) + } +} + +func BenchmarkParseFormat(b *testing.B) { + for i := 0; i < b.N; i++ { + parseFormat(CommonLogFormat) + } +} + +type testProvider map[string]metadata.Func + +func (tp testProvider) Metadata(ctx context.Context, state request.Request) context.Context { + for k, v := range tp { + metadata.SetValueFunc(ctx, k, v) + } + return ctx +} + +type testHandler struct{ ctx context.Context } + +func (m *testHandler) Name() string { return "test" } + +func (m *testHandler) ServeDNS(ctx context.Context, w 
dns.ResponseWriter, r *dns.Msg) (int, error) { + m.ctx = ctx + return 0, nil +} + +func TestMetadataReplacement(t *testing.T) { + tests := []struct { + expr string + result string + }{ + {"{/test/meta2}", "two"}, + {"{/test/meta2} {/test/key4}", "two -"}, + {"{/test/meta2} {/test/meta3}", "two three"}, + } + + next := &testHandler{} + m := metadata.Metadata{ + Zones: []string{"."}, + Providers: []metadata.Provider{ + testProvider{"test/meta2": func() string { return "two" }}, + testProvider{"test/meta3": func() string { return "three" }}, + }, + Next: next, + } + + w := dnstest.NewRecorder(&test.ResponseWriter{}) + r := new(dns.Msg) + r.SetQuestion("example.org.", dns.TypeHINFO) + + ctx := m.Collect(context.TODO(), request.Request{W: w, Req: r}) + + repl := New() + state := request.Request{W: w, Req: r} + + for i, ts := range tests { + r := repl.Replace(ctx, state, nil, ts.expr) + if r != ts.result { + t.Errorf("Test %d - expr : %s, expected %q, got %q", i, ts.expr, ts.result, r) + } + } +} + +func TestMetadataMalformed(t *testing.T) { + tests := []struct { + expr string + result string + }{ + {"{/test/meta2", "{/test/meta2"}, + {"{test/meta2} {/test/meta4}", "{test/meta2} -"}, + {"{test}", "{test}"}, + } + + next := &testHandler{} + m := metadata.Metadata{ + Zones: []string{"."}, + Providers: []metadata.Provider{testProvider{"test/meta2": func() string { return "two" }}}, + Next: next, + } + + m.ServeDNS(context.TODO(), &test.ResponseWriter{}, new(dns.Msg)) + ctx := next.ctx // important because the m.ServeDNS has only now populated the context + + w := dnstest.NewRecorder(&test.ResponseWriter{}) + r := new(dns.Msg) + r.SetQuestion("example.org.", dns.TypeHINFO) + + repl := New() + state := request.Request{W: w, Req: r} + + for i, ts := range tests { + r := repl.Replace(ctx, state, nil, ts.expr) + if r != ts.result { + t.Errorf("Test %d - expr : %s, expected %q, got %q", i, ts.expr, ts.result, r) + } + } +} + +func TestNoResponseWasWritten(t *testing.T) { + w := dnstest.NewRecorder(&test.ResponseWriter{}) + r := new(dns.Msg) + r.SetQuestion("example.org.", dns.TypeHINFO) + r.Id = 1053 + r.AuthenticatedData = true + r.CheckingDisabled = true + state := request.Request{W: w, Req: r} + + replacer := New() + ctx := context.TODO() + + // This couples the test very tightly to the code, but so be it. + expect := map[string]string{ + "{type}": "HINFO", + "{name}": "example.org.", + "{class}": "IN", + "{proto}": "udp", + "{size}": "29", + "{remote}": "10.240.0.1", + "{port}": "40212", + "{local}": "127.0.0.1", + headerReplacer + "id}": "1053", + headerReplacer + "opcode}": "0", + headerReplacer + "do}": "false", + headerReplacer + "bufsize}": "512", + "{rcode}": "-", + "{rsize}": "0", + "{duration}": "0", + headerReplacer + "rflags}": "-", + } + if len(expect) != len(labels) { + t.Fatalf("Expect %d labels, got %d", len(expect), len(labels)) + } + + for lbl := range labels { + repl := replacer.Replace(ctx, state, w, lbl) + if lbl == "{duration}" { + if repl[len(repl)-1] != 's' { + t.Errorf("Expected seconds, got %q", repl) + } + continue + } + if repl != expect[lbl] { + t.Errorf("Expected value %q, got %q", expect[lbl], repl) + } + } +} diff --git a/plugin/pkg/response/classify.go b/plugin/pkg/response/classify.go new file mode 100644 index 0000000..2e705cb --- /dev/null +++ b/plugin/pkg/response/classify.go @@ -0,0 +1,61 @@ +package response + +import "fmt" + +// Class holds sets of Types +type Class int + +const ( + // All is a meta class encompassing all the classes. 
+ All Class = iota + // Success is a class for a successful response. + Success + // Denial is a class for denying existence (NXDOMAIN, or a nodata: type does not exist) + Denial + // Error is a class for errors, right now defined as not Success and not Denial + Error +) + +func (c Class) String() string { + switch c { + case All: + return "all" + case Success: + return "success" + case Denial: + return "denial" + case Error: + return "error" + } + return "" +} + +// ClassFromString returns the class from the string s. If not class matches +// the All class and an error are returned +func ClassFromString(s string) (Class, error) { + switch s { + case "all": + return All, nil + case "success": + return Success, nil + case "denial": + return Denial, nil + case "error": + return Error, nil + } + return All, fmt.Errorf("invalid Class: %s", s) +} + +// Classify classifies the Type t, it returns its Class. +func Classify(t Type) Class { + switch t { + case NoError, Delegation: + return Success + case NameError, NoData: + return Denial + case OtherError: + fallthrough + default: + return Error + } +} diff --git a/plugin/pkg/response/typify.go b/plugin/pkg/response/typify.go new file mode 100644 index 0000000..df314d4 --- /dev/null +++ b/plugin/pkg/response/typify.go @@ -0,0 +1,151 @@ +package response + +import ( + "fmt" + "time" + + "github.com/miekg/dns" +) + +// Type is the type of the message. +type Type int + +const ( + // NoError indicates a positive reply + NoError Type = iota + // NameError is a NXDOMAIN in header, SOA in auth. + NameError + // ServerError is a set of errors we want to cache, for now it contains SERVFAIL and NOTIMPL. + ServerError + // NoData indicates name found, but not the type: NOERROR in header, SOA in auth. + NoData + // Delegation is a msg with a pointer to another nameserver: NOERROR in header, NS in auth, optionally fluff in additional (not checked). + Delegation + // Meta indicates a meta message, NOTIFY, or a transfer: qType is IXFR or AXFR. + Meta + // Update is an dynamic update message. + Update + // OtherError indicates any other error: don't cache these. + OtherError +) + +var toString = map[Type]string{ + NoError: "NOERROR", + NameError: "NXDOMAIN", + ServerError: "SERVERERROR", + NoData: "NODATA", + Delegation: "DELEGATION", + Meta: "META", + Update: "UPDATE", + OtherError: "OTHERERROR", +} + +func (t Type) String() string { return toString[t] } + +// TypeFromString returns the type from the string s. If not type matches +// the OtherError type and an error are returned. +func TypeFromString(s string) (Type, error) { + for t, str := range toString { + if s == str { + return t, nil + } + } + return NoError, fmt.Errorf("invalid Type: %s", s) +} + +// Typify classifies a message, it returns the Type. 
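+//
+// In rough order of precedence: dynamic updates and transfers/notifies map to
+// Update and Meta, messages carrying expired RRSIGs (only checked when the DO
+// bit is set) map to OtherError, NOERROR answers map to NoError, an SOA in the
+// authority section signals NoData or NameError, SERVFAIL and NOTIMPL map to
+// ServerError, and NS records in the authority section signal a Delegation.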
+func Typify(m *dns.Msg, t time.Time) (Type, *dns.OPT) { + if m == nil { + return OtherError, nil + } + opt := m.IsEdns0() + do := false + if opt != nil { + do = opt.Do() + } + + if m.Opcode == dns.OpcodeUpdate { + return Update, opt + } + + // Check transfer and update first + if m.Opcode == dns.OpcodeNotify { + return Meta, opt + } + + if len(m.Question) > 0 { + if m.Question[0].Qtype == dns.TypeAXFR || m.Question[0].Qtype == dns.TypeIXFR { + return Meta, opt + } + } + + // If our message contains any expired sigs and we care about that, we should return expired + if do { + if expired := typifyExpired(m, t); expired { + return OtherError, opt + } + } + + if len(m.Answer) > 0 && m.Rcode == dns.RcodeSuccess { + return NoError, opt + } + + soa := false + ns := 0 + for _, r := range m.Ns { + if r.Header().Rrtype == dns.TypeSOA { + soa = true + continue + } + if r.Header().Rrtype == dns.TypeNS { + ns++ + } + } + + if soa && m.Rcode == dns.RcodeSuccess { + return NoData, opt + } + if soa && m.Rcode == dns.RcodeNameError { + return NameError, opt + } + + if m.Rcode == dns.RcodeServerFailure || m.Rcode == dns.RcodeNotImplemented { + return ServerError, opt + } + + if ns > 0 && m.Rcode == dns.RcodeSuccess { + return Delegation, opt + } + + if m.Rcode == dns.RcodeSuccess { + return NoError, opt + } + + return OtherError, opt +} + +func typifyExpired(m *dns.Msg, t time.Time) bool { + if expired := typifyExpiredRRSIG(m.Answer, t); expired { + return true + } + if expired := typifyExpiredRRSIG(m.Ns, t); expired { + return true + } + if expired := typifyExpiredRRSIG(m.Extra, t); expired { + return true + } + return false +} + +func typifyExpiredRRSIG(rrs []dns.RR, t time.Time) bool { + for _, r := range rrs { + if r.Header().Rrtype != dns.TypeRRSIG { + continue + } + ok := r.(*dns.RRSIG).ValidityPeriod(t) + if !ok { + return true + } + } + return false +} diff --git a/plugin/pkg/response/typify_test.go b/plugin/pkg/response/typify_test.go new file mode 100644 index 0000000..3d9abdf --- /dev/null +++ b/plugin/pkg/response/typify_test.go @@ -0,0 +1,101 @@ +package response + +import ( + "testing" + "time" + + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestTypifyNilMsg(t *testing.T) { + var m *dns.Msg + + ty, _ := Typify(m, time.Now().UTC()) + if ty != OtherError { + t.Errorf("Message wrongly typified, expected OtherError, got %s", ty) + } +} + +func TestTypifyDelegation(t *testing.T) { + m := delegationMsg() + mt, _ := Typify(m, time.Now().UTC()) + if mt != Delegation { + t.Errorf("Message is wrongly typified, expected Delegation, got %s", mt) + } +} + +func TestTypifyRRSIG(t *testing.T) { + now, _ := time.Parse(time.UnixDate, "Fri Apr 21 10:51:21 BST 2017") + utc := now.UTC() + + m := delegationMsgRRSIGOK() + if mt, _ := Typify(m, utc); mt != Delegation { + t.Errorf("Message is wrongly typified, expected Delegation, got %s", mt) + } + + // Still a Delegation because EDNS0 OPT DO bool is not set, so we won't check the sigs. 
+ m = delegationMsgRRSIGFail() + if mt, _ := Typify(m, utc); mt != Delegation { + t.Errorf("Message is wrongly typified, expected Delegation, got %s", mt) + } + + m = delegationMsgRRSIGFail() + m.Extra = append(m.Extra, test.OPT(4096, true)) + if mt, _ := Typify(m, utc); mt != OtherError { + t.Errorf("Message is wrongly typified, expected OtherError, got %s", mt) + } +} + +func TestTypifyImpossible(t *testing.T) { + // create impossible message that denies its own existence + m := new(dns.Msg) + m.SetQuestion("bar.www.example.org.", dns.TypeAAAA) + m.Rcode = dns.RcodeNameError // name does not exist + m.Answer = []dns.RR{test.CNAME("bar.www.example.org. IN CNAME foo.example.org.")} // but we add a cname with the name! + mt, _ := Typify(m, time.Now().UTC()) + if mt != OtherError { + t.Errorf("Impossible message not typified as OtherError, got %s", mt) + } +} + +func TestTypifyRefused(t *testing.T) { + m := new(dns.Msg) + m.SetQuestion("foo.example.org.", dns.TypeA) + m.Rcode = dns.RcodeRefused + mt, _ := Typify(m, time.Now().UTC()) + if mt != OtherError { + t.Errorf("Refused message not typified as OtherError, got %s", mt) + } +} + +func delegationMsg() *dns.Msg { + return &dns.Msg{ + Ns: []dns.RR{ + test.NS("miek.nl. 3600 IN NS linode.atoom.net."), + test.NS("miek.nl. 3600 IN NS ns-ext.nlnetlabs.nl."), + test.NS("miek.nl. 3600 IN NS omval.tednet.nl."), + }, + Extra: []dns.RR{ + test.A("omval.tednet.nl. 3600 IN A 185.49.141.42"), + test.AAAA("omval.tednet.nl. 3600 IN AAAA 2a04:b900:0:100::42"), + }, + } +} + +func delegationMsgRRSIGOK() *dns.Msg { + del := delegationMsg() + del.Ns = append(del.Ns, + test.RRSIG("miek.nl. 1800 IN RRSIG NS 8 2 1800 20170521031301 20170421031301 12051 miek.nl. PIUu3TKX/sB/N1n1E1yWxHHIcPnc2q6Wq9InShk+5ptRqChqKdZNMLDm gCq+1bQAZ7jGvn2PbwTwE65JzES7T+hEiqR5PU23DsidvZyClbZ9l0xG JtKwgzGXLtUHxp4xv/Plq+rq/7pOG61bNCxRyS7WS7i7QcCCWT1BCcv+ wZ0="), + ) + return del +} + +func delegationMsgRRSIGFail() *dns.Msg { + del := delegationMsg() + del.Ns = append(del.Ns, + test.RRSIG("miek.nl. 1800 IN RRSIG NS 8 2 1800 20160521031301 20160421031301 12051 miek.nl. PIUu3TKX/sB/N1n1E1yWxHHIcPnc2q6Wq9InShk+5ptRqChqKdZNMLDm gCq+1bQAZ7jGvn2PbwTwE65JzES7T+hEiqR5PU23DsidvZyClbZ9l0xG JtKwgzGXLtUHxp4xv/Plq+rq/7pOG61bNCxRyS7WS7i7QcCCWT1BCcv+ wZ0="), + ) + return del +} diff --git a/plugin/pkg/reuseport/listen_no_reuseport.go b/plugin/pkg/reuseport/listen_no_reuseport.go new file mode 100644 index 0000000..1018a9b --- /dev/null +++ b/plugin/pkg/reuseport/listen_no_reuseport.go @@ -0,0 +1,13 @@ +//go:build !go1.11 || (!aix && !darwin && !dragonfly && !freebsd && !linux && !netbsd && !openbsd) + +package reuseport + +import "net" + +// Listen is a wrapper around net.Listen. +func Listen(network, addr string) (net.Listener, error) { return net.Listen(network, addr) } + +// ListenPacket is a wrapper around net.ListenPacket. 
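+//
+// Both helpers are drop-in replacements for the corresponding net package
+// calls; a typical caller does, for example (sketch):
+//
+//	ln, err := reuseport.Listen("tcp", ":8181")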
+func ListenPacket(network, addr string) (net.PacketConn, error) { + return net.ListenPacket(network, addr) +} diff --git a/plugin/pkg/reuseport/listen_reuseport.go b/plugin/pkg/reuseport/listen_reuseport.go new file mode 100644 index 0000000..71fac3e --- /dev/null +++ b/plugin/pkg/reuseport/listen_reuseport.go @@ -0,0 +1,36 @@ +//go:build go1.11 && (aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd) + +package reuseport + +import ( + "context" + "net" + "syscall" + + "github.com/coredns/coredns/plugin/pkg/log" + + "golang.org/x/sys/unix" +) + +func control(network, address string, c syscall.RawConn) error { + c.Control(func(fd uintptr) { + if err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil { + log.Warningf("Failed to set SO_REUSEPORT on socket: %s", err) + } + }) + return nil +} + +// Listen announces on the local network address. See net.Listen for more information. +// If SO_REUSEPORT is available it will be set on the socket. +func Listen(network, addr string) (net.Listener, error) { + lc := net.ListenConfig{Control: control} + return lc.Listen(context.Background(), network, addr) +} + +// ListenPacket announces on the local network address. See net.ListenPacket for more information. +// If SO_REUSEPORT is available it will be set on the socket. +func ListenPacket(network, addr string) (net.PacketConn, error) { + lc := net.ListenConfig{Control: control} + return lc.ListenPacket(context.Background(), network, addr) +} diff --git a/plugin/pkg/singleflight/singleflight.go b/plugin/pkg/singleflight/singleflight.go new file mode 100644 index 0000000..e70646c --- /dev/null +++ b/plugin/pkg/singleflight/singleflight.go @@ -0,0 +1,64 @@ +/* +Copyright 2012 Google Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package singleflight provides a duplicate function call suppression +// mechanism. +package singleflight + +import "sync" + +// call is an in-flight or completed Do call +type call struct { + wg sync.WaitGroup + val interface{} + err error +} + +// Group represents a class of work and forms a namespace in which +// units of work can be executed with duplicate suppression. +type Group struct { + mu sync.Mutex // protects m + m map[uint64]*call // lazily initialized +} + +// Do executes and returns the results of the given function, making +// sure that only one execution is in-flight for a given key at a +// time. If a duplicate comes in, the duplicate caller waits for the +// original to complete and receives the same results. 
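+//
+// A minimal usage sketch (the key and the wrapped lookup are hypothetical):
+//
+//	var g Group
+//	v, err := g.Do(42, func() (interface{}, error) {
+//		return doLookup() // executed once, even with concurrent callers for key 42
+//	})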
+func (g *Group) Do(key uint64, fn func() (interface{}, error)) (interface{}, error) { + g.mu.Lock() + if g.m == nil { + g.m = make(map[uint64]*call) + } + if c, ok := g.m[key]; ok { + g.mu.Unlock() + c.wg.Wait() + return c.val, c.err + } + c := new(call) + c.wg.Add(1) + g.m[key] = c + g.mu.Unlock() + + c.val, c.err = fn() + c.wg.Done() + + g.mu.Lock() + delete(g.m, key) + g.mu.Unlock() + + return c.val, c.err +} diff --git a/plugin/pkg/singleflight/singleflight_test.go b/plugin/pkg/singleflight/singleflight_test.go new file mode 100644 index 0000000..0e75d41 --- /dev/null +++ b/plugin/pkg/singleflight/singleflight_test.go @@ -0,0 +1,85 @@ +/* +Copyright 2012 Google Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package singleflight + +import ( + "errors" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestDo(t *testing.T) { + var g Group + v, err := g.Do(1, func() (interface{}, error) { + return "bar", nil + }) + if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want { + t.Errorf("Do = %v; want %v", got, want) + } + if err != nil { + t.Errorf("Do error = %v", err) + } +} + +func TestDoErr(t *testing.T) { + var g Group + someErr := errors.New("some error") + v, err := g.Do(1, func() (interface{}, error) { + return nil, someErr + }) + if err != someErr { + t.Errorf("Do error = %v; want someErr", err) + } + if v != nil { + t.Errorf("Unexpected non-nil value %#v", v) + } +} + +func TestDoDupSuppress(t *testing.T) { + var g Group + c := make(chan string) + var calls int32 + fn := func() (interface{}, error) { + atomic.AddInt32(&calls, 1) + return <-c, nil + } + + const n = 10 + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + go func() { + v, err := g.Do(1, fn) + if err != nil { + t.Errorf("Do error: %v", err) + } + if v.(string) != "bar" { + t.Errorf("Got %q; want %q", v, "bar") + } + wg.Done() + }() + } + time.Sleep(100 * time.Millisecond) // let goroutines above block + c <- "bar" + wg.Wait() + if got := atomic.LoadInt32(&calls); got != 1 { + t.Errorf("Number of calls = %d; want 1", got) + } +} diff --git a/plugin/pkg/tls/tls.go b/plugin/pkg/tls/tls.go new file mode 100644 index 0000000..41eff4b --- /dev/null +++ b/plugin/pkg/tls/tls.go @@ -0,0 +1,149 @@ +package tls + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "net" + "net/http" + "os" + "path/filepath" + "time" +) + +func setTLSDefaults(ctls *tls.Config) { + ctls.MinVersion = tls.VersionTLS12 + ctls.MaxVersion = tls.VersionTLS13 + ctls.CipherSuites = []uint16{ + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + } +} + +// NewTLSConfigFromArgs returns a TLS config based upon the passed +// in list of arguments. 
Typically these come straight from the +// Corefile. +// no args +// - creates a Config with no cert and using system CAs +// - use for a client that talks to a server with a public signed cert (CA installed in system) +// - the client will not be authenticated by the server since there is no cert +// +// one arg: the path to CA PEM file +// - creates a Config with no cert using a specific CA +// - use for a client that talks to a server with a private signed cert (CA not installed in system) +// - the client will not be authenticated by the server since there is no cert +// +// two args: path to cert PEM file, the path to private key PEM file +// - creates a Config with a cert, using system CAs to validate the other end +// - use for: +// - a server; or, +// - a client that talks to a server with a public cert and needs certificate-based authentication +// - the other end will authenticate this end via the provided cert +// - the cert of the other end will be verified via system CAs +// +// three args: path to cert PEM file, path to client private key PEM file, path to CA PEM file +// - creates a Config with the cert, using specified CA to validate the other end +// - use for: +// - a server; or, +// - a client that talks to a server with a privately signed cert and needs certificate-based +// authentication +// - the other end will authenticate this end via the provided cert +// - this end will verify the other end's cert using the specified CA +func NewTLSConfigFromArgs(args ...string) (*tls.Config, error) { + var err error + var c *tls.Config + switch len(args) { + case 0: + // No client cert, use system CA + c, err = NewTLSClientConfig("") + case 1: + // No client cert, use specified CA + c, err = NewTLSClientConfig(args[0]) + case 2: + // Client cert, use system CA + c, err = NewTLSConfig(args[0], args[1], "") + case 3: + // Client cert, use specified CA + c, err = NewTLSConfig(args[0], args[1], args[2]) + default: + err = fmt.Errorf("maximum of three arguments allowed for TLS config, found %d", len(args)) + } + if err != nil { + return nil, err + } + return c, nil +} + +// NewTLSConfig returns a TLS config that includes a certificate +// Use for server TLS config or when using a client certificate +// If caPath is empty, system CAs will be used +func NewTLSConfig(certPath, keyPath, caPath string) (*tls.Config, error) { + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return nil, fmt.Errorf("could not load TLS cert: %s", err) + } + + roots, err := loadRoots(caPath) + if err != nil { + return nil, err + } + + tlsConfig := &tls.Config{Certificates: []tls.Certificate{cert}, RootCAs: roots} + setTLSDefaults(tlsConfig) + + return tlsConfig, nil +} + +// NewTLSClientConfig returns a TLS config for a client connection +// If caPath is empty, system CAs will be used +func NewTLSClientConfig(caPath string) (*tls.Config, error) { + roots, err := loadRoots(caPath) + if err != nil { + return nil, err + } + + tlsConfig := &tls.Config{RootCAs: roots} + setTLSDefaults(tlsConfig) + + return tlsConfig, nil +} + +func loadRoots(caPath string) (*x509.CertPool, error) { + if caPath == "" { + return nil, nil + } + + roots := x509.NewCertPool() + pem, err := os.ReadFile(filepath.Clean(caPath)) + if err != nil { + return nil, fmt.Errorf("error reading %s: %s", caPath, err) + } + ok := roots.AppendCertsFromPEM(pem) + if !ok { + return nil, fmt.Errorf("could not read root certs: %s", err) + } + return roots, nil +} + +// NewHTTPSTransport returns an HTTP transport configured using 
tls.Config +func NewHTTPSTransport(cc *tls.Config) *http.Transport { + tr := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + Dial: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).Dial, + TLSHandshakeTimeout: 10 * time.Second, + TLSClientConfig: cc, + MaxIdleConnsPerHost: 25, + } + + return tr +} diff --git a/plugin/pkg/tls/tls_test.go b/plugin/pkg/tls/tls_test.go new file mode 100644 index 0000000..a5635c1 --- /dev/null +++ b/plugin/pkg/tls/tls_test.go @@ -0,0 +1,127 @@ +package tls + +import ( + "os" + "path/filepath" + "testing" + + "github.com/coredns/coredns/plugin/test" +) + +func getPEMFiles(t *testing.T) (cert, key, ca string) { + tempDir, err := test.WritePEMFiles(t) + if err != nil { + t.Fatalf("Could not write PEM files: %s", err) + } + + cert = filepath.Join(tempDir, "cert.pem") + key = filepath.Join(tempDir, "key.pem") + ca = filepath.Join(tempDir, "ca.pem") + + return +} + +func TestNewTLSConfig(t *testing.T) { + cert, key, ca := getPEMFiles(t) + _, err := NewTLSConfig(cert, key, ca) + if err != nil { + t.Errorf("Failed to create TLSConfig: %s", err) + } +} + +func TestNewTLSClientConfig(t *testing.T) { + _, _, ca := getPEMFiles(t) + + _, err := NewTLSClientConfig(ca) + if err != nil { + t.Errorf("Failed to create TLSConfig: %s", err) + } +} + +func TestNewTLSConfigFromArgs(t *testing.T) { + cert, key, ca := getPEMFiles(t) + + _, err := NewTLSConfigFromArgs() + if err != nil { + t.Errorf("Failed to create TLSConfig: %s", err) + } + + c, err := NewTLSConfigFromArgs(ca) + if err != nil { + t.Errorf("Failed to create TLSConfig: %s", err) + } + if c.RootCAs == nil { + t.Error("RootCAs should not be nil when one arg passed") + } + + c, err = NewTLSConfigFromArgs(cert, key) + if err != nil { + t.Errorf("Failed to create TLSConfig: %s", err) + } + if c.RootCAs != nil { + t.Error("RootCAs should be nil when two args passed") + } + if len(c.Certificates) != 1 { + t.Error("Certificates should have a single entry when two args passed") + } + args := []string{cert, key, ca} + c, err = NewTLSConfigFromArgs(args...) + if err != nil { + t.Errorf("Failed to create TLSConfig: %s", err) + } + if c.RootCAs == nil { + t.Error("RootCAs should not be nil when three args passed") + } + if len(c.Certificates) != 1 { + t.Error("Certificates should have a single entry when three args passed") + } +} + +func TestNewTLSConfigFromArgsWithRoot(t *testing.T) { + cert, key, ca := getPEMFiles(t) + tempDir, err := os.MkdirTemp("", "go-test-pemfiles") + defer func() { + if err := os.RemoveAll(tempDir); err != nil { + t.Error("failed to clean up temporary directory", err) + } + }() + if err != nil { + t.Error("failed to create temporary directory", err) + } + root := tempDir + args := []string{cert, key, ca} + for i := range args { + if !filepath.IsAbs(args[i]) && root != "" { + args[i] = filepath.Join(root, args[i]) + } + } + c, err := NewTLSConfigFromArgs(args...) 
+ if err != nil { + t.Errorf("Failed to create TLSConfig: %s", err) + } + if c.RootCAs == nil { + t.Error("RootCAs should not be nil when three args passed") + } + if len(c.Certificates) != 1 { + t.Error("Certificates should have a single entry when three args passed") + } +} + +func TestNewHTTPSTransport(t *testing.T) { + _, _, ca := getPEMFiles(t) + + cc, err := NewTLSClientConfig(ca) + if err != nil { + t.Errorf("Failed to create TLSConfig: %s", err) + } + + tr := NewHTTPSTransport(cc) + if tr == nil { + t.Errorf("Failed to create https transport with cc") + } + + tr = NewHTTPSTransport(nil) + if tr == nil { + t.Errorf("Failed to create https transport without cc") + } +} diff --git a/plugin/pkg/trace/trace.go b/plugin/pkg/trace/trace.go new file mode 100644 index 0000000..6585d80 --- /dev/null +++ b/plugin/pkg/trace/trace.go @@ -0,0 +1,13 @@ +package trace + +import ( + "github.com/coredns/coredns/plugin" + + ot "github.com/opentracing/opentracing-go" +) + +// Trace holds the tracer and endpoint info +type Trace interface { + plugin.Handler + Tracer() ot.Tracer +} diff --git a/plugin/pkg/transport/transport.go b/plugin/pkg/transport/transport.go new file mode 100644 index 0000000..cdb2c79 --- /dev/null +++ b/plugin/pkg/transport/transport.go @@ -0,0 +1,25 @@ +package transport + +// These transports are supported by CoreDNS. +const ( + DNS = "dns" + TLS = "tls" + QUIC = "quic" + GRPC = "grpc" + HTTPS = "https" + UNIX = "unix" +) + +// Port numbers for the various transports. +const ( + // Port is the default port for DNS + Port = "53" + // TLSPort is the default port for DNS-over-TLS. + TLSPort = "853" + // QUICPort is the default port for DNS-over-QUIC. + QUICPort = "853" + // GRPCPort is the default port for DNS-over-gRPC. + GRPCPort = "443" + // HTTPSPort is the default port for DNS-over-HTTPS. + HTTPSPort = "443" +) diff --git a/plugin/pkg/uniq/uniq.go b/plugin/pkg/uniq/uniq.go new file mode 100644 index 0000000..5f95e41 --- /dev/null +++ b/plugin/pkg/uniq/uniq.go @@ -0,0 +1,46 @@ +// Package uniq keeps track of "thing" that are either "todo" or "done". Multiple +// identical events will only be processed once. +package uniq + +// U keeps track of item to be done. +type U struct { + u map[string]item +} + +type item struct { + state int // either todo or done + f func() error // function to be executed. +} + +// New returns a new initialized U. +func New() U { return U{u: make(map[string]item)} } + +// Set sets function f in U under key. If the key already exists it is not overwritten. +func (u U) Set(key string, f func() error) { + if _, ok := u.u[key]; ok { + return + } + u.u[key] = item{todo, f} +} + +// Unset removes the key. +func (u U) Unset(key string) { + delete(u.u, key) +} + +// ForEach iterates over u and executes f for each element that is 'todo' and sets it to 'done'. 
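+//
+// A minimal sketch of the intended flow (key and function are hypothetical):
+//
+//	u := New()
+//	u.Set("start-listener", startListener) // later Sets for the same key are ignored
+//	u.ForEach()                            // runs startListener once and marks it done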
+func (u U) ForEach() error { + for k, v := range u.u { + if v.state == todo { + v.f() + } + v.state = done + u.u[k] = v + } + return nil +} + +const ( + todo = 1 + done = 2 +) diff --git a/plugin/pkg/uniq/uniq_test.go b/plugin/pkg/uniq/uniq_test.go new file mode 100644 index 0000000..5d58c92 --- /dev/null +++ b/plugin/pkg/uniq/uniq_test.go @@ -0,0 +1,17 @@ +package uniq + +import "testing" + +func TestForEach(t *testing.T) { + u, i := New(), 0 + u.Set("test", func() error { i++; return nil }) + + u.ForEach() + if i != 1 { + t.Errorf("Failed to executed f for %s", "test") + } + u.ForEach() + if i != 1 { + t.Errorf("Executed f twice instead of once") + } +} diff --git a/plugin/pkg/up/up.go b/plugin/pkg/up/up.go new file mode 100644 index 0000000..649107f --- /dev/null +++ b/plugin/pkg/up/up.go @@ -0,0 +1,83 @@ +// Package up is used to run a function for some duration. If a new function is added while a previous run is +// still ongoing, nothing new will be executed. +package up + +import ( + "sync" + "time" +) + +// Probe is used to run a single Func until it returns true (indicating a target is healthy). If an Func +// is already in progress no new one will be added, i.e. there is always a maximum of 1 checks in flight. +// +// There is a tradeoff to be made in figuring out quickly that an upstream is healthy and not doing much work +// (sending queries) to find that out. Having some kind of exp. backoff here won't help much, because you don't want +// to backoff too much. You then also need random queries to be performed every so often to quickly detect a working +// upstream. In the end we just send a query every 0.5 second to check the upstream. This hopefully strikes a balance +// between getting information about the upstream state quickly and not doing too much work. Note that 0.5s is still an +// eternity in DNS, so we may actually want to shorten it. +type Probe struct { + sync.Mutex + inprogress int + interval time.Duration +} + +// Func is used to determine if a target is alive. If so this function must return nil. +type Func func() error + +// New returns a pointer to an initialized Probe. +func New() *Probe { return &Probe{} } + +// Do will probe target, if a probe is already in progress this is a noop. +func (p *Probe) Do(f Func) { + p.Lock() + if p.inprogress != idle { + p.Unlock() + return + } + p.inprogress = active + interval := p.interval + p.Unlock() + // Passed the lock. Now run f for as long it returns false. If a true is returned + // we return from the goroutine and we can accept another Func to run. + go func() { + i := 1 + for { + if err := f(); err == nil { + break + } + time.Sleep(interval) + p.Lock() + if p.inprogress == stop { + p.Unlock() + return + } + p.Unlock() + i++ + } + + p.Lock() + p.inprogress = idle + p.Unlock() + }() +} + +// Stop stops the probing. +func (p *Probe) Stop() { + p.Lock() + p.inprogress = stop + p.Unlock() +} + +// Start will initialize the probe manager, after which probes can be initiated with Do. 
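+//
+// A minimal sketch (the health check function is hypothetical):
+//
+//	p := New()
+//	p.Start(500 * time.Millisecond)
+//	p.Do(func() error { return checkUpstream() }) // retried every interval until it returns nil
+//	defer p.Stop()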
+func (p *Probe) Start(interval time.Duration) { + p.Lock() + p.interval = interval + p.Unlock() +} + +const ( + idle = iota + active + stop +) diff --git a/plugin/pkg/up/up_test.go b/plugin/pkg/up/up_test.go new file mode 100644 index 0000000..eeaecea --- /dev/null +++ b/plugin/pkg/up/up_test.go @@ -0,0 +1,40 @@ +package up + +import ( + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestUp(t *testing.T) { + pr := New() + wg := sync.WaitGroup{} + hits := int32(0) + + upfunc := func() error { + atomic.AddInt32(&hits, 1) + // Sleep tiny amount so that our other pr.Do() calls hit the lock. + time.Sleep(3 * time.Millisecond) + wg.Done() + return nil + } + + pr.Start(5 * time.Millisecond) + defer pr.Stop() + + // These functions AddInt32 to the same hits variable, but we only want to wait when + // upfunc finishes, as that only calls Done() on the waitgroup. + upfuncNoWg := func() error { atomic.AddInt32(&hits, 1); return nil } + wg.Add(1) + pr.Do(upfunc) + pr.Do(upfuncNoWg) + pr.Do(upfuncNoWg) + + wg.Wait() + + h := atomic.LoadInt32(&hits) + if h != 1 { + t.Errorf("Expected hits to be %d, got %d", 1, h) + } +} diff --git a/plugin/pkg/upstream/upstream.go b/plugin/pkg/upstream/upstream.go new file mode 100644 index 0000000..b531b70 --- /dev/null +++ b/plugin/pkg/upstream/upstream.go @@ -0,0 +1,35 @@ +// Package upstream abstracts a upstream lookups so that plugins can handle them in an unified way. +package upstream + +import ( + "context" + "fmt" + + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin/pkg/nonwriter" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// Upstream is used to resolve CNAME or other external targets via CoreDNS itself. +type Upstream struct{} + +// New creates a new Upstream to resolve names using the coredns process. +func New() *Upstream { return &Upstream{} } + +// Lookup routes lookups to our selves to make it follow the plugin chain *again*, but with a (possibly) new query. As +// we are doing the query against ourselves again, there is no actual new hop, as such RFC 6891 does not apply and we +// need the EDNS0 option present in the *original* query to be present here too. +func (u *Upstream) Lookup(ctx context.Context, state request.Request, name string, typ uint16) (*dns.Msg, error) { + server, ok := ctx.Value(dnsserver.Key{}).(*dnsserver.Server) + if !ok { + return nil, fmt.Errorf("no full server is running") + } + req := state.NewWithQuestion(name, typ) + + nw := nonwriter.New(state.W) + server.ServeDNS(ctx, nw, req.Req) + + return nw.Msg, nil +} diff --git a/plugin/plugin.go b/plugin/plugin.go new file mode 100644 index 0000000..51f5ba7 --- /dev/null +++ b/plugin/plugin.go @@ -0,0 +1,112 @@ +// Package plugin provides some types and functions common among plugin. +package plugin + +import ( + "context" + "errors" + "fmt" + + "github.com/miekg/dns" + ot "github.com/opentracing/opentracing-go" + "github.com/prometheus/client_golang/prometheus" +) + +type ( + // Plugin is a middle layer which represents the traditional + // idea of plugin: it chains one Handler to the next by being + // passed the next Handler in the chain. + Plugin func(Handler) Handler + + // Handler is like dns.Handler except ServeDNS may return an rcode + // and/or error. + // + // If ServeDNS writes to the response body, it should return a status + // code. 
CoreDNS assumes *no* reply has yet been written if the status + // code is one of the following: + // + // * SERVFAIL (dns.RcodeServerFailure) + // + // * REFUSED (dns.RecodeRefused) + // + // * FORMERR (dns.RcodeFormatError) + // + // * NOTIMP (dns.RcodeNotImplemented) + // + // All other response codes signal other handlers above it that the + // response message is already written, and that they should not write + // to it also. + // + // If ServeDNS encounters an error, it should return the error value + // so it can be logged by designated error-handling plugin. + // + // If writing a response after calling another ServeDNS method, the + // returned rcode SHOULD be used when writing the response. + // + // If handling errors after calling another ServeDNS method, the + // returned error value SHOULD be logged or handled accordingly. + // + // Otherwise, return values should be propagated down the plugin + // chain by returning them unchanged. + Handler interface { + ServeDNS(context.Context, dns.ResponseWriter, *dns.Msg) (int, error) + Name() string + } + + // HandlerFunc is a convenience type like dns.HandlerFunc, except + // ServeDNS returns an rcode and an error. See Handler + // documentation for more information. + HandlerFunc func(context.Context, dns.ResponseWriter, *dns.Msg) (int, error) +) + +// ServeDNS implements the Handler interface. +func (f HandlerFunc) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + return f(ctx, w, r) +} + +// Name implements the Handler interface. +func (f HandlerFunc) Name() string { return "handlerfunc" } + +// Error returns err with 'plugin/name: ' prefixed to it. +func Error(name string, err error) error { return fmt.Errorf("%s/%s: %s", "plugin", name, err) } + +// NextOrFailure calls next.ServeDNS when next is not nil, otherwise it will return, a ServerFailure and a `no next plugin found` error. +func NextOrFailure(name string, next Handler, ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { // nolint: golint + if next != nil { + if span := ot.SpanFromContext(ctx); span != nil { + child := span.Tracer().StartSpan(next.Name(), ot.ChildOf(span.Context())) + defer child.Finish() + ctx = ot.ContextWithSpan(ctx, child) + } + return next.ServeDNS(ctx, w, r) + } + + return dns.RcodeServerFailure, Error(name, errors.New("no next plugin found")) +} + +// ClientWrite returns true if the response has been written to the client. +// Each plugin to adhere to this protocol. +func ClientWrite(rcode int) bool { + switch rcode { + case dns.RcodeServerFailure: + fallthrough + case dns.RcodeRefused: + fallthrough + case dns.RcodeFormatError: + fallthrough + case dns.RcodeNotImplemented: + return false + } + return true +} + +// Namespace is the namespace used for the metrics. +const Namespace = "coredns" + +// TimeBuckets is based on Prometheus client_golang prometheus.DefBuckets +var TimeBuckets = prometheus.ExponentialBuckets(0.00025, 2, 16) // from 0.25ms to 8 seconds + +// SlimTimeBuckets is low cardinality set of duration buckets. +var SlimTimeBuckets = prometheus.ExponentialBuckets(0.00025, 10, 5) // from 0.25ms to 2.5 seconds + +// ErrOnce is returned when a plugin doesn't support multiple setups per server. 
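+// Setup functions typically return it wrapped by Error, for example (sketch):
+//
+//	return plugin.Error("pprof", plugin.ErrOnce)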
+var ErrOnce = errors.New("this plugin can only be used once per Server Block") diff --git a/plugin/pprof/README.md b/plugin/pprof/README.md new file mode 100644 index 0000000..c63d152 --- /dev/null +++ b/plugin/pprof/README.md @@ -0,0 +1,74 @@ +# pprof + +## Name + +*pprof* - publishes runtime profiling data at endpoints under `/debug/pprof`. + +## Description + +You can visit `/debug/pprof` on your site for an index of the available endpoints. By default it +will listen on localhost:6053. + +This is a debugging tool. Certain requests (such as collecting execution traces) can be slow. If +you use pprof on a live server, consider restricting access or enabling it only temporarily. + +This plugin can only be used once per Server Block. + +## Syntax + +~~~ txt +pprof [ADDRESS] +~~~ + +Optionally pprof takes an address; the default is `localhost:6053`. + +An extra option can be set with this extended syntax: + +~~~ txt +pprof [ADDRESS] { + block [RATE] +} +~~~ + +* `block` option enables block profiling, **RATE** defaults to 1. **RATE** must be a positive value. + See [Diagnostics, chapter profiling](https://golang.org/doc/diagnostics.html) and + [runtime.SetBlockProfileRate](https://golang.org/pkg/runtime/#SetBlockProfileRate) for what block + profiling entails. + +## Examples + +Enable a pprof endpoint: + +~~~ +. { + pprof +} +~~~ + +And use the pprof tool to get statistics: `go tool pprof http://localhost:6053`. + +Listen on an alternate address: + +~~~ txt +. { + pprof 10.9.8.7:6060 +} +~~~ + +Listen on an all addresses on port 6060, and enable block profiling + +~~~ txt +. { + pprof :6060 { + block + } +} +~~~ + +## See Also + +See [Go's pprof documentation](https://golang.org/pkg/net/http/pprof/) and [Profiling Go +Programs](https://blog.golang.org/profiling-go-programs). + +See [runtime.SetBlockProfileRate](https://golang.org/pkg/runtime/#SetBlockProfileRate) for +background on block profiling. diff --git a/plugin/pprof/log_test.go b/plugin/pprof/log_test.go new file mode 100644 index 0000000..7e2c252 --- /dev/null +++ b/plugin/pprof/log_test.go @@ -0,0 +1,5 @@ +package pprof + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/pprof/pprof.go b/plugin/pprof/pprof.go new file mode 100644 index 0000000..822e6e2 --- /dev/null +++ b/plugin/pprof/pprof.go @@ -0,0 +1,60 @@ +// Package pprof implements a debug endpoint for getting profiles using the +// go pprof tooling. +package pprof + +import ( + "net" + "net/http" + pp "net/http/pprof" + "runtime" + + "github.com/coredns/coredns/plugin/pkg/reuseport" +) + +type handler struct { + addr string + rateBloc int + ln net.Listener + mux *http.ServeMux +} + +func (h *handler) Startup() error { + // Reloading the plugin without changing the listening address results + // in an error unless we reuse the port because Startup is called for + // new handlers before Shutdown is called for the old ones. 
+ ln, err := reuseport.Listen("tcp", h.addr) + if err != nil { + log.Errorf("Failed to start pprof handler: %s", err) + return err + } + + h.ln = ln + + h.mux = http.NewServeMux() + h.mux.HandleFunc(path, func(rw http.ResponseWriter, req *http.Request) { + http.Redirect(rw, req, path+"/", http.StatusFound) + }) + h.mux.HandleFunc(path+"/", pp.Index) + h.mux.HandleFunc(path+"/cmdline", pp.Cmdline) + h.mux.HandleFunc(path+"/profile", pp.Profile) + h.mux.HandleFunc(path+"/symbol", pp.Symbol) + h.mux.HandleFunc(path+"/trace", pp.Trace) + + runtime.SetBlockProfileRate(h.rateBloc) + + go func() { + http.Serve(h.ln, h.mux) + }() + return nil +} + +func (h *handler) Shutdown() error { + if h.ln != nil { + return h.ln.Close() + } + return nil +} + +const ( + path = "/debug/pprof" +) diff --git a/plugin/pprof/setup.go b/plugin/pprof/setup.go new file mode 100644 index 0000000..3505b5d --- /dev/null +++ b/plugin/pprof/setup.go @@ -0,0 +1,65 @@ +package pprof + +import ( + "net" + "strconv" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/plugin" + clog "github.com/coredns/coredns/plugin/pkg/log" +) + +var log = clog.NewWithPlugin("pprof") + +const defaultAddr = "localhost:6053" + +func init() { plugin.Register("pprof", setup) } + +func setup(c *caddy.Controller) error { + h := &handler{addr: defaultAddr} + + i := 0 + for c.Next() { + if i > 0 { + return plugin.Error("pprof", plugin.ErrOnce) + } + i++ + + args := c.RemainingArgs() + if len(args) == 1 { + h.addr = args[0] + _, _, e := net.SplitHostPort(h.addr) + if e != nil { + return plugin.Error("pprof", c.Errf("%v", e)) + } + } + + if len(args) > 1 { + return plugin.Error("pprof", c.ArgErr()) + } + + for c.NextBlock() { + switch c.Val() { + case "block": + args := c.RemainingArgs() + if len(args) > 1 { + return plugin.Error("pprof", c.ArgErr()) + } + h.rateBloc = 1 + if len(args) > 0 { + t, err := strconv.Atoi(args[0]) + if err != nil { + return plugin.Error("pprof", c.Errf("property '%s' invalid integer value '%v'", "block", args[0])) + } + h.rateBloc = t + } + default: + return plugin.Error("pprof", c.Errf("unknown property '%s'", c.Val())) + } + } + } + + c.OnStartup(h.Startup) + c.OnShutdown(h.Shutdown) + return nil +} diff --git a/plugin/pprof/setup_test.go b/plugin/pprof/setup_test.go new file mode 100644 index 0000000..500a400 --- /dev/null +++ b/plugin/pprof/setup_test.go @@ -0,0 +1,44 @@ +package pprof + +import ( + "testing" + + "github.com/coredns/caddy" +) + +func TestPProf(t *testing.T) { + tests := []struct { + input string + shouldErr bool + }{ + {`pprof`, false}, + {`pprof 1.2.3.4:1234`, false}, + {`pprof :1234`, false}, + {`pprof :1234 -1`, true}, + {`pprof { + }`, false}, + {`pprof /foo`, true}, + {`pprof { + a b + }`, true}, + {`pprof { block + }`, false}, + {`pprof :1234 { + block 20 + }`, false}, + {`pprof { + block 20 30 + }`, true}, + {`pprof + pprof`, true}, + } + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + err := setup(c) + if test.shouldErr && err == nil { + t.Errorf("Test %v: Expected error but found nil", i) + } else if !test.shouldErr && err != nil { + t.Errorf("Test %v: Expected no error but found error: %v", i, err) + } + } +} diff --git a/plugin/ready/README.md b/plugin/ready/README.md new file mode 100644 index 0000000..d2e430d --- /dev/null +++ b/plugin/ready/README.md @@ -0,0 +1,58 @@ +# ready + +## Name + +*ready* - enables a readiness check HTTP endpoint. 
+ +## Description + +By enabling *ready* an HTTP endpoint on port 8181 will return 200 OK, when all plugins that are able +to signal readiness have done so. If some are not ready yet the endpoint will return a 503 with the +body containing the list of plugins that are not ready. Once a plugin has signaled it is ready it +will not be queried again. + +Each Server Block that enables the *ready* plugin will have the plugins *in that server block* +report readiness into the /ready endpoint that runs on the same port. This also means that the +*same* plugin with different configurations (in potentially *different* Server Blocks) will have +their readiness reported as the union of their respective readinesses. + +## Syntax + +~~~ +ready [ADDRESS] +~~~ + +*ready* optionally takes an address; the default is `:8181`. The path is fixed to `/ready`. The +readiness endpoint returns a 200 response code and the word "OK" when this server is ready. It +returns a 503 otherwise *and* the list of plugins that are not ready. + +## Plugins + +Any plugin wanting to signal readiness will need to implement the `ready.Readiness` interface by +implementing a method `Ready() bool` that returns true when the plugin is ready and false otherwise. + +## Examples + +Let *ready* report readiness for both the `.` and `example.org` servers (assuming the *whois* +plugin also exports readiness): + +~~~ txt +. { + ready + erratic +} + +example.org { + ready + whoami +} + +~~~ + +Run *ready* on a different port. + +~~~ txt +. { + ready localhost:8091 +} +~~~ diff --git a/plugin/ready/list.go b/plugin/ready/list.go new file mode 100644 index 0000000..c246287 --- /dev/null +++ b/plugin/ready/list.go @@ -0,0 +1,56 @@ +package ready + +import ( + "sort" + "strings" + "sync" +) + +// list is a structure that holds the plugins that signals readiness for this server block. +type list struct { + sync.RWMutex + rs []Readiness + names []string +} + +// Reset resets l +func (l *list) Reset() { + l.Lock() + defer l.Unlock() + l.rs = nil + l.names = nil +} + +// Append adds a new readiness to l. +func (l *list) Append(r Readiness, name string) { + l.Lock() + defer l.Unlock() + l.rs = append(l.rs, r) + l.names = append(l.names, name) +} + +// Ready return true when all plugins ready, if the returned value is false the string +// contains a comma separated list of plugins that are not ready. +func (l *list) Ready() (bool, string) { + l.RLock() + defer l.RUnlock() + ok := true + s := []string{} + for i, r := range l.rs { + if r == nil { + continue + } + if !r.Ready() { + ok = false + s = append(s, l.names[i]) + } else { + // if ok, this plugin is ready and will not be queried anymore. + l.rs[i] = nil + } + } + if ok { + return true, "" + } + sort.Strings(s) + return false, strings.Join(s, ",") +} diff --git a/plugin/ready/readiness.go b/plugin/ready/readiness.go new file mode 100644 index 0000000..7aca5df --- /dev/null +++ b/plugin/ready/readiness.go @@ -0,0 +1,7 @@ +package ready + +// The Readiness interface needs to be implemented by each plugin willing to provide a readiness check. +type Readiness interface { + // Ready is called by ready to see whether the plugin is ready. + Ready() bool +} diff --git a/plugin/ready/ready.go b/plugin/ready/ready.go new file mode 100644 index 0000000..2002e4a --- /dev/null +++ b/plugin/ready/ready.go @@ -0,0 +1,81 @@ +// Package ready is used to signal readiness of the CoreDNS process. 
Once all +// plugins have called in the plugin will signal readiness by returning a 200 +// OK on the HTTP handler (on port 8181). If not ready yet, the handler will +// return a 503. +package ready + +import ( + "io" + "net" + "net/http" + "sync" + + clog "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/plugin/pkg/reuseport" + "github.com/coredns/coredns/plugin/pkg/uniq" +) + +var ( + log = clog.NewWithPlugin("ready") + plugins = &list{} + uniqAddr = uniq.New() +) + +type ready struct { + Addr string + + sync.RWMutex + ln net.Listener + done bool + mux *http.ServeMux +} + +func (rd *ready) onStartup() error { + ln, err := reuseport.Listen("tcp", rd.Addr) + if err != nil { + return err + } + + rd.Lock() + rd.ln = ln + rd.mux = http.NewServeMux() + rd.done = true + rd.Unlock() + + rd.mux.HandleFunc("/ready", func(w http.ResponseWriter, _ *http.Request) { + rd.Lock() + defer rd.Unlock() + if !rd.done { + w.WriteHeader(http.StatusServiceUnavailable) + io.WriteString(w, "Shutting down") + return + } + ok, todo := plugins.Ready() + if ok { + w.WriteHeader(http.StatusOK) + io.WriteString(w, http.StatusText(http.StatusOK)) + return + } + log.Infof("Still waiting on: %q", todo) + w.WriteHeader(http.StatusServiceUnavailable) + io.WriteString(w, todo) + }) + + go func() { http.Serve(rd.ln, rd.mux) }() + + return nil +} + +func (rd *ready) onFinalShutdown() error { + rd.Lock() + defer rd.Unlock() + if !rd.done { + return nil + } + + uniqAddr.Unset(rd.Addr) + + rd.ln.Close() + rd.done = false + return nil +} diff --git a/plugin/ready/ready_test.go b/plugin/ready/ready_test.go new file mode 100644 index 0000000..414541c --- /dev/null +++ b/plugin/ready/ready_test.go @@ -0,0 +1,69 @@ +package ready + +import ( + "context" + "fmt" + "net/http" + "testing" + + "github.com/coredns/coredns/plugin/erratic" + clog "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func init() { clog.Discard() } + +func TestReady(t *testing.T) { + rd := &ready{Addr: ":0"} + e := &erratic.Erratic{} + plugins.Append(e, "erratic") + + if err := rd.onStartup(); err != nil { + t.Fatalf("Unable to startup the readiness server: %v", err) + } + + defer rd.onFinalShutdown() + + address := fmt.Sprintf("http://%s/ready", rd.ln.Addr().String()) + + response, err := http.Get(address) + if err != nil { + t.Fatalf("Unable to query %s: %v", address, err) + } + if response.StatusCode != http.StatusServiceUnavailable { + t.Errorf("Invalid status code: expecting %d, got %d", 503, response.StatusCode) + } + response.Body.Close() + + // make it ready by giving erratic 3 queries. 
+ m := new(dns.Msg) + m.SetQuestion("example.org.", dns.TypeA) + e.ServeDNS(context.TODO(), &test.ResponseWriter{}, m) + e.ServeDNS(context.TODO(), &test.ResponseWriter{}, m) + e.ServeDNS(context.TODO(), &test.ResponseWriter{}, m) + + response, err = http.Get(address) + if err != nil { + t.Fatalf("Unable to query %s: %v", address, err) + } + if response.StatusCode != http.StatusOK { + t.Errorf("Invalid status code: expecting %d, got %d", 200, response.StatusCode) + } + response.Body.Close() + + // make erratic not-ready by giving it more queries, this should not change the process readiness + e.ServeDNS(context.TODO(), &test.ResponseWriter{}, m) + e.ServeDNS(context.TODO(), &test.ResponseWriter{}, m) + e.ServeDNS(context.TODO(), &test.ResponseWriter{}, m) + + response, err = http.Get(address) + if err != nil { + t.Fatalf("Unable to query %s: %v", address, err) + } + if response.StatusCode != http.StatusOK { + t.Errorf("Invalid status code: expecting %d, got %d", 200, response.StatusCode) + } + response.Body.Close() +} diff --git a/plugin/ready/setup.go b/plugin/ready/setup.go new file mode 100644 index 0000000..e5657f6 --- /dev/null +++ b/plugin/ready/setup.go @@ -0,0 +1,73 @@ +package ready + +import ( + "net" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" +) + +func init() { plugin.Register("ready", setup) } + +func setup(c *caddy.Controller) error { + addr, err := parse(c) + if err != nil { + return plugin.Error("ready", err) + } + rd := &ready{Addr: addr} + + uniqAddr.Set(addr, rd.onStartup) + c.OnStartup(func() error { uniqAddr.Set(addr, rd.onStartup); return nil }) + c.OnRestartFailed(func() error { uniqAddr.Set(addr, rd.onStartup); return nil }) + + c.OnStartup(func() error { return uniqAddr.ForEach() }) + c.OnRestartFailed(func() error { return uniqAddr.ForEach() }) + + c.OnStartup(func() error { + plugins.Reset() + for _, p := range dnsserver.GetConfig(c).Handlers() { + if r, ok := p.(Readiness); ok { + plugins.Append(r, p.Name()) + } + } + return nil + }) + c.OnRestartFailed(func() error { + for _, p := range dnsserver.GetConfig(c).Handlers() { + if r, ok := p.(Readiness); ok { + plugins.Append(r, p.Name()) + } + } + return nil + }) + + c.OnRestart(rd.onFinalShutdown) + c.OnFinalShutdown(rd.onFinalShutdown) + + return nil +} + +func parse(c *caddy.Controller) (string, error) { + addr := ":8181" + i := 0 + for c.Next() { + if i > 0 { + return "", plugin.ErrOnce + } + i++ + args := c.RemainingArgs() + + switch len(args) { + case 0: + case 1: + addr = args[0] + if _, _, e := net.SplitHostPort(addr); e != nil { + return "", e + } + default: + return "", c.ArgErr() + } + } + return addr, nil +} diff --git a/plugin/ready/setup_test.go b/plugin/ready/setup_test.go new file mode 100644 index 0000000..1dd0d4a --- /dev/null +++ b/plugin/ready/setup_test.go @@ -0,0 +1,34 @@ +package ready + +import ( + "testing" + + "github.com/coredns/caddy" +) + +func TestSetupReady(t *testing.T) { + tests := []struct { + input string + shouldErr bool + }{ + {`ready`, false}, + {`ready localhost:1234`, false}, + {`ready localhost:1234 b`, true}, + {`ready bla`, true}, + {`ready bla bla`, true}, + } + + for i, test := range tests { + _, err := parse(caddy.NewTestController("dns", test.input)) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected error but found none for input %s", i, test.input) + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: Expected no error but found one for input %s. 
Error was: %v", i, test.input, err) + } + } + } +} diff --git a/plugin/register.go b/plugin/register.go new file mode 100644 index 0000000..16090ff --- /dev/null +++ b/plugin/register.go @@ -0,0 +1,11 @@ +package plugin + +import "github.com/coredns/caddy" + +// Register registers your plugin with CoreDNS and allows it to be called when the server is running. +func Register(name string, action caddy.SetupFunc) { + caddy.RegisterPlugin(name, caddy.Plugin{ + ServerType: "dns", + Action: action, + }) +} diff --git a/plugin/reload/README.md b/plugin/reload/README.md new file mode 100644 index 0000000..b4dff55 --- /dev/null +++ b/plugin/reload/README.md @@ -0,0 +1,108 @@ +# reload + +## Name + +*reload* - allows automatic reload of a changed Corefile. + +## Description + +This plugin allows automatic reload of a changed _Corefile_. +To enable automatic reloading of _zone file_ changes, use the `auto` plugin. + +This plugin periodically checks if the Corefile has changed by reading +it and calculating its SHA512 checksum. If the file has changed, it reloads +CoreDNS with the new Corefile. This eliminates the need to send a SIGHUP +or SIGUSR1 after changing the Corefile. + +The reloads are graceful - you should not see any loss of service when the +reload happens. Even if the new Corefile has an error, CoreDNS will continue +to run the old config and an error message will be printed to the log. But see +the Bugs section for failure modes. + +In some environments (for example, Kubernetes), there may be many CoreDNS +instances that started very near the same time and all share a common +Corefile. To prevent these all from reloading at the same time, some +jitter is added to the reload check interval. This is jitter from the +perspective of multiple CoreDNS instances; each instance still checks on a +regular interval, but all of these instances will have their reloads spread +out across the jitter duration. This isn't strictly necessary given that the +reloads are graceful, and can be disabled by setting the jitter to `0s`. + +Jitter is re-calculated whenever the Corefile is reloaded. + +This plugin can only be used once per Server Block. + +## Syntax + +~~~ txt +reload [INTERVAL] [JITTER] +~~~ + +The plugin will check for changes every **INTERVAL**, subject to +/- the **JITTER** duration. + +* **INTERVAL** and **JITTER** are Golang [durations](https://golang.org/pkg/time/#ParseDuration). + The default **INTERVAL** is 30s, default **JITTER** is 15s, the minimal value for **INTERVAL** + is 2s, and for **JITTER** it is 1s. If **JITTER** is more than half of **INTERVAL**, it will be + set to half of **INTERVAL** + +## Examples + +Check with the default intervals: + +~~~ corefile +. { + reload + erratic +} +~~~ + +Check every 10 seconds (jitter is automatically set to 10 / 2 = 5 in this case): + +~~~ corefile +. { + reload 10s + erratic +} +~~~ + +## Bugs + +The reload happens without data loss (i.e. DNS queries keep flowing), but there is a corner case +where the reload fails, and you lose functionality. Consider the following Corefile: + +~~~ txt +. { + health :8080 + whoami +} +~~~ + +CoreDNS starts and serves health from :8080. Now you change `:8080` to `:443` not knowing a process +is already listening on that port. The process reloads and performs the following steps: + +1. close the listener on 8080 +2. reload and parse the config again +3. fail to start a new listener on 443 +4. 
fail loading the new Corefile, abort and keep using the old process + +After the aborted attempt to reload we are left with the old processes running, but the listener is +closed in step 1; so the health endpoint is broken. The same can happen in the prometheus plugin. + +In general be careful with assigning new port and expecting reload to work fully. + +In CoreDNS v1.6.0 and earlier any `import` statements are not discovered by this plugin. +This means if any of these imported files changes the *reload* plugin is ignorant of that fact. +CoreDNS v1.7.0 and later does parse the Corefile and supports detecting changes in imported files. + +## Metrics + + If monitoring is enabled (via the *prometheus* plugin) then the following metric is exported: + +* `coredns_reload_failed_total{}` - counts the number of failed reload attempts. +* `coredns_reload_version_info{hash, value}` - record the hash value during reload. + +Currently the type of `hash` is "sha512", the `value` is the returned hash value. + +## See Also + +See coredns-import(7) and corefile(5). diff --git a/plugin/reload/log_test.go b/plugin/reload/log_test.go new file mode 100644 index 0000000..2f6598d --- /dev/null +++ b/plugin/reload/log_test.go @@ -0,0 +1,5 @@ +package reload + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/reload/metrics.go b/plugin/reload/metrics.go new file mode 100644 index 0000000..7224791 --- /dev/null +++ b/plugin/reload/metrics.go @@ -0,0 +1,26 @@ +package reload + +import ( + "github.com/coredns/coredns/plugin" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +// Metrics for the reload plugin +var ( + // failedCount is the counter of the number of failed reload attempts. + failedCount = promauto.NewCounter(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "reload", + Name: "failed_total", + Help: "Counter of the number of failed reload attempts.", + }) + // reloadInfo is record the hash value during reload. + reloadInfo = promauto.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: plugin.Namespace, + Subsystem: "reload", + Name: "version_info", + Help: "A metric with a constant '1' value labeled by hash, and value which type of hash generated.", + }, []string{"hash", "value"}) +) diff --git a/plugin/reload/reload.go b/plugin/reload/reload.go new file mode 100644 index 0000000..917681c --- /dev/null +++ b/plugin/reload/reload.go @@ -0,0 +1,128 @@ +// Package reload periodically checks if the Corefile has changed, and reloads if so. 
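+// The check interval and jitter are configurable via the reload directive in the Corefile; see README.md.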
+package reload + +import ( + "bytes" + "crypto/sha512" + "encoding/hex" + "encoding/json" + "sync" + "time" + + "github.com/coredns/caddy" + "github.com/coredns/caddy/caddyfile" + + "github.com/prometheus/client_golang/prometheus" +) + +const ( + unused = 0 + maybeUsed = 1 + used = 2 +) + +type reload struct { + dur time.Duration + u int + mtx sync.RWMutex + quit chan bool +} + +func (r *reload) setUsage(u int) { + r.mtx.Lock() + defer r.mtx.Unlock() + r.u = u +} + +func (r *reload) usage() int { + r.mtx.RLock() + defer r.mtx.RUnlock() + return r.u +} + +func (r *reload) setInterval(i time.Duration) { + r.mtx.Lock() + defer r.mtx.Unlock() + r.dur = i +} + +func (r *reload) interval() time.Duration { + r.mtx.RLock() + defer r.mtx.RUnlock() + return r.dur +} + +func parse(corefile caddy.Input) ([]byte, error) { + serverBlocks, err := caddyfile.Parse(corefile.Path(), bytes.NewReader(corefile.Body()), nil) + if err != nil { + return nil, err + } + return json.Marshal(serverBlocks) +} + +func hook(event caddy.EventName, info interface{}) error { + if event != caddy.InstanceStartupEvent { + return nil + } + // if reload is removed from the Corefile, then the hook + // is still registered but setup is never called again + // so we need a flag to tell us not to reload + if r.usage() == unused { + return nil + } + + // this should be an instance. ok to panic if not + instance := info.(*caddy.Instance) + parsedCorefile, err := parse(instance.Caddyfile()) + if err != nil { + return err + } + + sha512sum := sha512.Sum512(parsedCorefile) + log.Infof("Running configuration SHA512 = %x\n", sha512sum) + + go func() { + tick := time.NewTicker(r.interval()) + defer tick.Stop() + + for { + select { + case <-tick.C: + corefile, err := caddy.LoadCaddyfile(instance.Caddyfile().ServerType()) + if err != nil { + continue + } + parsedCorefile, err := parse(corefile) + if err != nil { + log.Warningf("Corefile parse failed: %s", err) + continue + } + s := sha512.Sum512(parsedCorefile) + if s != sha512sum { + reloadInfo.Delete(prometheus.Labels{"hash": "sha512", "value": hex.EncodeToString(sha512sum[:])}) + // Let not try to restart with the same file, even though it is wrong. + sha512sum = s + // now lets consider that plugin will not be reload, unless appear in next config file + // change status of usage will be reset in setup if the plugin appears in config file + r.setUsage(maybeUsed) + _, err := instance.Restart(corefile) + reloadInfo.WithLabelValues("sha512", hex.EncodeToString(sha512sum[:])).Set(1) + if err != nil { + log.Errorf("Corefile changed but reload failed: %s", err) + failedCount.Add(1) + continue + } + // we are done, if the plugin was not set used, then it is not. + if r.usage() == maybeUsed { + r.setUsage(unused) + } + return + } + case <-r.quit: + return + } + } + }() + + return nil +} diff --git a/plugin/reload/setup.go b/plugin/reload/setup.go new file mode 100644 index 0000000..6df3234 --- /dev/null +++ b/plugin/reload/setup.go @@ -0,0 +1,87 @@ +package reload + +import ( + "fmt" + "math/rand" + "sync" + "time" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/plugin" + clog "github.com/coredns/coredns/plugin/pkg/log" +) + +var log = clog.NewWithPlugin("reload") + +func init() { plugin.Register("reload", setup) } + +// the info reload is global to all application, whatever number of reloads. +// it is used to transmit data between Setup and start of the hook called 'onInstanceStartup' +// channel for QUIT is never changed in purpose. 
+// WARNING: this data may be unsync after an invalid attempt of reload Corefile. +var ( + r = reload{dur: defaultInterval, u: unused, quit: make(chan bool)} + once, shutOnce sync.Once +) + +func setup(c *caddy.Controller) error { + c.Next() // 'reload' + args := c.RemainingArgs() + + if len(args) > 2 { + return plugin.Error("reload", c.ArgErr()) + } + + i := defaultInterval + if len(args) > 0 { + d, err := time.ParseDuration(args[0]) + if err != nil { + return plugin.Error("reload", err) + } + i = d + } + if i < minInterval { + return plugin.Error("reload", fmt.Errorf("interval value must be greater or equal to %v", minInterval)) + } + + j := defaultJitter + if len(args) > 1 { + d, err := time.ParseDuration(args[1]) + if err != nil { + return plugin.Error("reload", err) + } + j = d + } + if j < minJitter { + return plugin.Error("reload", fmt.Errorf("jitter value must be greater or equal to %v", minJitter)) + } + + if j > i/2 { + j = i / 2 + } + + jitter := time.Duration(rand.Int63n(j.Nanoseconds()) - (j.Nanoseconds() / 2)) + i = i + jitter + + // prepare info for next onInstanceStartup event + r.setInterval(i) + r.setUsage(used) + once.Do(func() { + caddy.RegisterEventHook("reload", hook) + }) + // re-register on finalShutDown as the instance most-likely will be changed + shutOnce.Do(func() { + c.OnFinalShutdown(func() error { + r.quit <- true + return nil + }) + }) + return nil +} + +const ( + minJitter = 1 * time.Second + minInterval = 2 * time.Second + defaultInterval = 30 * time.Second + defaultJitter = 15 * time.Second +) diff --git a/plugin/reload/setup_test.go b/plugin/reload/setup_test.go new file mode 100644 index 0000000..5450fae --- /dev/null +++ b/plugin/reload/setup_test.go @@ -0,0 +1,51 @@ +package reload + +import ( + "testing" + + "github.com/coredns/caddy" +) + +func TestSetupReload(t *testing.T) { + c := caddy.NewTestController("dns", `reload`) + if err := setup(c); err != nil { + t.Fatalf("Expected no errors, but got: %v", err) + } + + c = caddy.NewTestController("dns", `reload 10s`) + if err := setup(c); err != nil { + t.Fatalf("Expected no errors, but got: %v", err) + } + + c = caddy.NewTestController("dns", `reload 10s 2s`) + if err := setup(c); err != nil { + t.Fatalf("Expected no errors, but got: %v", err) + } + + c = caddy.NewTestController("dns", `reload foo`) + if err := setup(c); err == nil { + t.Fatalf("Expected errors, but got: %v", err) + } + + c = caddy.NewTestController("dns", `reload 10s foo`) + if err := setup(c); err == nil { + t.Fatalf("Expected errors, but got: %v", err) + } + + c = caddy.NewTestController("dns", `reload 10s 5s foo`) + if err := setup(c); err == nil { + t.Fatalf("Expected errors, but got: %v", err) + } + c = caddy.NewTestController("dns", `reload 1s`) + if err := setup(c); err == nil { + t.Fatalf("Expected errors, but got: %v", err) + } + c = caddy.NewTestController("dns", `reload 0s`) + if err := setup(c); err == nil { + t.Fatalf("Expected errors, but got: %v", err) + } + c = caddy.NewTestController("dns", `reload 3s 0.5s`) + if err := setup(c); err == nil { + t.Fatalf("Expected errors, but got: %v", err) + } +} diff --git a/plugin/rewrite/README.md b/plugin/rewrite/README.md new file mode 100644 index 0000000..895ef63 --- /dev/null +++ b/plugin/rewrite/README.md @@ -0,0 +1,509 @@ +# rewrite + +## Name + +*rewrite* - performs internal message rewriting. + +## Description + +Rewrites are invisible to the client. 
There are simple rewrites (fast) and complex rewrites +(slower), but they're powerful enough to accommodate most dynamic back-end applications. + +## Syntax + +A simplified/easy-to-digest syntax for *rewrite* is... +~~~ +rewrite [continue|stop] FIELD [TYPE] [(FROM TO)|TTL] [OPTIONS] +~~~ + +* **FIELD** indicates what part of the request/response is being re-written. + + * `type` - the type field of the request will be rewritten. FROM/TO must be a DNS record type (`A`, `MX`, etc.); +e.g., to rewrite ANY queries to HINFO, use `rewrite type ANY HINFO`. + * `name` - the query name in the _request_ is rewritten; by default this is a full match of the + name, e.g., `rewrite name example.net example.org`. Other match types are supported, see the **Name Field Rewrites** section below. + * `class` - the class of the message will be rewritten. FROM/TO must be a DNS class type (`IN`, `CH`, or `HS`); e.g., to rewrite CH queries to IN use `rewrite class CH IN`. + * `edns0` - an EDNS0 option can be appended to the request as described below in the **EDNS0 Options** section. + * `ttl` - the TTL value in the _response_ is rewritten. + * `cname` - the CNAME target if the response has a CNAME record + * `rcode` - the response code (RCODE) value in the _response_ is rewritten. + +* **TYPE** this optional element can be specified for a `name` or `ttl` field. + If not given type `exact` will be assumed. If options should be specified the + type must be given. +* **FROM** is the name (exact, suffix, prefix, substring, or regex) or type to match +* **TO** is the destination name or type to rewrite to +* **TTL** is the number of seconds to set the TTL value to (only for field `ttl`) + +* **OPTIONS** + + for field `name` further options are possible controlling the response rewrites. + All name matching types support the following options + + * `answer auto` - the names in the _response_ is rewritten in a best effort manner. + * `answer name FROM TO` - the query name in the _response_ is rewritten matching the from regex pattern. + * `answer value FROM TO` - the names in the _response_ is rewritten matching the from regex pattern. + + See below in the **Response Rewrites** section for further details. + +If you specify multiple rules and an incoming query matches multiple rules, the rewrite +will behave as follows: + + * `continue` will continue applying the next rule in the rule list. + * `stop` will consider the current rule the last rule and will not continue. The default behaviour is `stop` + +## Examples + +### Name Field Rewrites + +The `rewrite` plugin offers the ability to match the name in the question section of +a DNS request. The match could be exact, a substring match, or based on a prefix, suffix, or regular +expression. If the newly used name is not a legal domain name, the plugin returns an error to the +client. + +The syntax for name rewriting is as follows: + +``` +rewrite [continue|stop] name [exact|prefix|suffix|substring|regex] STRING STRING [OPTIONS] +``` + +The match type, e.g., `exact`, `substring`, etc., triggers rewrite: + +* **exact** (default): on an exact match of the name in the question section of a request +* **substring**: on a partial match of the name in the question section of a request +* **prefix**: when the name begins with the matching string +* **suffix**: when the name ends with the matching string +* **regex**: when the name in the question section of a request matches a regular expression + +If the match type is omitted, the `exact` match type is assumed. 
If OPTIONS are +given, the type must be specified. + +The following instruction allows rewriting names in the query that +contain the substring `service.us-west-1.example.org`: + +``` +rewrite name substring service.us-west-1.example.org service.us-west-1.consul +``` + +Thus: + +* Incoming Request Name: `ftp.service.us-west-1.example.org` +* Rewritten Request Name: `ftp.service.us-west-1.consul` + +The following instruction uses regular expressions. Names in requests +matching the regular expression `(.*)-(us-west-1)\.example\.org` are replaced with +`{1}.service.{2}.consul`, where `{1}` and `{2}` are regular expression match groups. + +``` +rewrite name regex (.*)-(us-west-1)\.example\.org {1}.service.{2}.consul +``` + +Thus: + +* Incoming Request Name: `ftp-us-west-1.example.org` +* Rewritten Request Name: `ftp.service.us-west-1.consul` + +The following example rewrites the `schmoogle.com` suffix to `google.com`. + +~~~ +rewrite name suffix .schmoogle.com. .google.com. +~~~ + +### Response Rewrites + +When rewriting incoming DNS requests' names (field `name`), CoreDNS re-writes +the `QUESTION SECTION` +section of the requests. It may be necessary to rewrite the `ANSWER SECTION` of the +requests, because some DNS resolvers treat mismatches between the `QUESTION SECTION` +and `ANSWER SECTION` as a man-in-the-middle attack (MITM). + +For example, a user tries to resolve `ftp-us-west-1.coredns.rocks`. The +CoreDNS configuration file has the following rule: + +``` +rewrite name regex (.*)-(us-west-1)\.coredns\.rocks {1}.service.{2}.consul +``` + +CoreDNS rewrote the request from `ftp-us-west-1.coredns.rocks` to +`ftp.service.us-west-1.consul` and ultimately resolved it to 3 records. +The resolved records, in the `ANSWER SECTION` below, were not from `coredns.rocks`, but +rather from `service.us-west-1.consul`. + + +``` +$ dig @10.1.1.1 ftp-us-west-1.coredns.rocks + +;; QUESTION SECTION: +;ftp-us-west-1.coredns.rocks. IN A + +;; ANSWER SECTION: +ftp.service.us-west-1.consul. 0 IN A 10.10.10.10 +ftp.service.us-west-1.consul. 0 IN A 10.20.20.20 +ftp.service.us-west-1.consul. 0 IN A 10.30.30.30 +``` + +The above is a mismatch between the question asked and the answer provided. + +There are three possibilities to specify an answer rewrite: +- A rewrite can request a best effort answer rewrite by adding the option `answer auto`. +- A rewrite may specify a dedicated regex based response name rewrite with the + `answer name FROM TO` option. +- A regex based rewrite of record values like `CNAME`, `SRV`, etc, can be requested by + an `answer value FROM TO` option. + +Hereby FROM/TO follow the rules for the `regex` name rewrite syntax. + +#### Auto Response Name Rewrite + +The following configuration snippet allows for rewriting of the +`ANSWER SECTION` according to the rewrite of the `QUESTION SECTION`: + +``` + rewrite stop { + name suffix .coredns.rocks .service.consul answer auto + } +``` + +Any occurrence of the rewritten question in the answer is mapped +back to the original value before the rewrite. + +Please note that answers for rewrites of type `exact` are always rewritten. +For a `suffix` name rule `auto` leads to a reverse suffix response rewrite, +exchanging FROM and TO from the rewrite request. 
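+
+As an illustration (addresses are hypothetical), with the suffix rule above a query for
+`ftp.coredns.rocks` has its question rewritten to `ftp.service.consul`, and the names in the
+reply are mapped back before the response is returned to the client:
+
+```
+;; QUESTION SECTION:
+;ftp.coredns.rocks.		IN	A
+
+;; ANSWER SECTION:
+ftp.coredns.rocks.	0	IN	A	10.10.10.10
+```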
+ +#### Explicit Response Name Rewrite + +The following configuration snippet allows for rewriting of the +`ANSWER SECTION`, provided that the `QUESTION SECTION` was rewritten: + +``` + rewrite stop { + name regex (.*)-(us-west-1)\.coredns\.rocks {1}.service.{2}.consul + answer name (.*)\.service\.(us-west-1)\.consul {1}-{2}.coredns.rocks + } +``` + +Now, the `ANSWER SECTION` matches the `QUESTION SECTION`: + +``` +$ dig @10.1.1.1 ftp-us-west-1.coredns.rocks + +;; QUESTION SECTION: +;ftp-us-west-1.coredns.rocks. IN A + +;; ANSWER SECTION: +ftp-us-west-1.coredns.rocks. 0 IN A 10.10.10.10 +ftp-us-west-1.coredns.rocks. 0 IN A 10.20.20.20 +ftp-us-west-1.coredns.rocks. 0 IN A 10.30.30.30 +``` + +#### Rewriting other Response Values + +It is also possible to rewrite other values returned in the DNS response records +(e.g. the server names returned in `SRV` and `MX` records). This can be enabled by adding +the `answer value FROM TO` option to a name rule as specified below. `answer value` takes a +regular expression and a rewrite name as parameters and works in the same way as the +`answer name` rule. + +Note that names in the `AUTHORITY SECTION` and `ADDITIONAL SECTION` will also be +rewritten following the specified rules. The names returned by the following +record types: `CNAME`, `DNAME`, `SOA`, `SRV`, `MX`, `NAPTR`, `NS`, `PTR` will be rewritten +if the `answer value` rule is specified. + +The syntax for the rewrite of DNS request and response is as follows: + +``` +rewrite [continue|stop] { + name regex STRING STRING + answer name STRING STRING + [answer value STRING STRING] +} +``` + +Note that the above syntax is strict. For response rewrites, only `name` +rules are allowed to match the question section. The answer rewrite must be +after the name, as in the syntax example. + +##### Example: PTR Response Value Rewrite + +The original response contains the domain `service.consul.` in the `VALUE` part +of the `ANSWER SECTION` + +``` +$ dig @10.1.1.1 30.30.30.10.in-addr.arpa PTR + +;; QUESTION SECTION: +;30.30.30.10.in-addr.arpa. IN PTR + +;; ANSWER SECTION: +30.30.30.10.in-addr.arpa. 60 IN PTR ftp-us-west-1.service.consul. +``` + +The following configuration snippet allows for rewriting of the value +in the `ANSWER SECTION`: + +``` + rewrite stop { + name suffix .arpa .arpa + answer name auto + answer value (.*)\.service\.consul\. {1}.coredns.rocks. + } +``` + +Now, the `VALUE` in the `ANSWER SECTION` has been overwritten in the domain part: + +``` +$ dig @10.1.1.1 30.30.30.10.in-addr.arpa PTR + +;; QUESTION SECTION: +;30.30.30.10.in-addr.arpa. IN PTR + +;; ANSWER SECTION: +30.30.30.10.in-addr.arpa. 60 IN PTR ftp-us-west-1.coredns.rocks. +``` + +#### Multiple Response Rewrites + +`name` and `value` rewrites can be chained by appending multiple answer rewrite +options. For all occurrences but the first one the keyword `answer` might be +omitted. + +```options +answer (auto | (name|value FROM TO)) { [answer] (auto | (name|value FROM TO)) } +``` + +For example: +``` +rewrite [continue|stop] name regex FROM TO answer name FROM TO [answer] value FROM TO +``` + +When using `exact` name rewrite rules, the answer gets rewritten automatically, +and there is no need to define `answer name auto`. But it is still possible to define +additional `answer value` and `answer value` options. + +The rule below rewrites the name in a request from `RED` to `BLUE`, and subsequently +rewrites the name in a corresponding response from `BLUE` to `RED`. The +client in the request would see only `RED` and no `BLUE`. 
+ +``` +rewrite [continue|stop] name exact RED BLUE +``` + +### TTL Field Rewrites + +At times, the need to rewrite a TTL value could arise. For example, a DNS server +may not cache records with a TTL of zero (`0`). An administrator +may want to increase the TTL to ensure it is cached, e.g., by increasing it to 15 seconds. + +In the below example, the TTL in the answers for `coredns.rocks` domain are +being set to `15`: + +``` + rewrite continue { + ttl regex (.*)\.coredns\.rocks 15 + } +``` + +By the same token, an administrator may use this feature to prevent or limit caching by +setting the TTL value really low. + + +The syntax for the TTL rewrite rule is as follows. The meaning of +`exact|prefix|suffix|substring|regex` is the same as with the name rewrite rules. +An omitted type is defaulted to `exact`. + +``` +rewrite [continue|stop] ttl [exact|prefix|suffix|substring|regex] STRING [SECONDS|MIN-MAX] +``` + +It is possible to supply a range of TTL values in the `SECONDS` parameters instead of a single value. +If a range is supplied, the TTL value is set to `MIN` if it is below, or set to `MAX` if it is above. +The TTL value is left unchanged if it is already inside the provided range. +The ranges can be unbounded on either side. + +TTL examples with ranges: +``` +# rewrite TTL to be between 30s and 300s +rewrite ttl example.com. 30-300 + +# cap TTL at 30s +rewrite ttl example.com. -30 # equivalent to rewrite ttl example.com. 0-30 + +# increase TTL to a minimum of 30s +rewrite ttl example.com. 30- + +# set TTL to 30s +rewrite ttl example.com. 30 # equivalent to rewrite ttl example.com. 30-30 +``` + +### RCODE Field Rewrites + +At times, the need to rewrite a RCODE value could arise. For example, a DNS server +may respond with a SERVFAIL instead of NOERROR records when AAAA records are requested. + +In the below example, the rcode value the answer for `coredns.rocks` the replies with SERVFAIL +is being switched to NOERROR. + +This example rewrites all the *.coredns.rocks domain SERVFAIL errors to NOERROR +``` + rewrite continue { + rcode regex (.*)\.coredns\.rocks SERVFAIL NOERROR + } +``` + +The same result numeric values: +``` + rewrite continue { + rcode regex (.*)\.coredns\.rocks 2 0 + } +``` + +The syntax for the RCODE rewrite rule is as follows. The meaning of +`exact|prefix|suffix|substring|regex` is the same as with the name rewrite rules. +An omitted type is defaulted to `exact`. + +``` +rewrite [continue|stop] rcode [exact|prefix|suffix|substring|regex] STRING FROM TO +``` + +The values of FROM and TO can be any of the following, text value or numeric: + +``` + 0 NOERROR + 1 FORMERR + 2 SERVFAIL + 3 NXDOMAIN + 4 NOTIMP + 5 REFUSED + 6 YXDOMAIN + 7 YXRRSET + 8 NXRRSET + 9 NOTAUTH + 10 NOTZONE + 16 BADSIG + 17 BADKEY + 18 BADTIME + 19 BADMODE + 20 BADNAME + 21 BADALG + 22 BADTRUNC + 23 BADCOOKIE +``` + + +## EDNS0 Options + +Using the FIELD edns0, you can set, append, or replace specific EDNS0 options in the request. + +* `replace` will modify any "matching" option with the specified option. The criteria for "matching" varies based on EDNS0 type. +* `append` will add the option only if no matching option exists +* `set` will modify a matching option or add one if none is found + +Currently supported are `EDNS0_LOCAL`, `EDNS0_NSID` and `EDNS0_SUBNET`. + +### EDNS0_LOCAL + +This has two fields, code and data. A match is defined as having the same code. Data may be a string or a variable. + +* A string data is treated as hex if it starts with `0x`. Example: + +~~~ corefile +. 
{ + rewrite edns0 local set 0xffee 0x61626364 + whoami +} +~~~ + +rewrites the first local option with code 0xffee, setting the data to "abcd". This is equivalent to: + +~~~ corefile +. { + rewrite edns0 local set 0xffee abcd +} +~~~ + +* A variable data is specified with a pair of curly brackets `{}`. Following are the supported variables: + {qname}, {qtype}, {client_ip}, {client_port}, {protocol}, {server_ip}, {server_port}. + +* If the metadata plugin is enabled, then labels are supported as variables if they are presented within curly brackets. +The variable data will be replaced with the value associated with that label. If that label is not provided, +the variable will be silently substituted with an empty string. + +Examples: + +~~~ +rewrite edns0 local set 0xffee {client_ip} +~~~ + +The following example uses metadata and an imaginary "some-plugin" that would provide "some-label" as metadata information. + +~~~ +metadata +some-plugin +rewrite edns0 local set 0xffee {some-plugin/some-label} +~~~ + +### EDNS0_NSID + +This has no fields; it will add an NSID option with an empty string for the NSID. If the option already exists +and the action is `replace` or `set`, then the NSID in the option will be set to the empty string. + +### EDNS0_SUBNET + +This has two fields, IPv4 bitmask length and IPv6 bitmask length. The bitmask +length is used to extract the client subnet from the source IP address in the query. + +Example: + +~~~ +rewrite edns0 subnet set 24 56 +~~~ + +* If the query's source IP address is an IPv4 address, the first 24 bits in the IP will be the network subnet. +* If the query's source IP address is an IPv6 address, the first 56 bits in the IP will be the network subnet. + + +### CNAME Field Rewrites + +There might be a scenario where you want the `CNAME` target of the response to be rewritten. You can do this by using the `CNAME` field rewrite. This will generate new answer records according to the new `CNAME` target. + +The syntax for the CNAME rewrite rule is as follows. The meaning of +`exact|prefix|suffix|substring|regex` is the same as with the name rewrite rules. +An omitted type is defaulted to `exact`. + +``` +rewrite [continue|stop] cname [exact|prefix|suffix|substring|regex] FROM TO +``` + +Consider the following `CNAME` rewrite rule with regex type. +``` +rewrite cname regex (.*).cdn.example.net. {1}.other.cdn.com. +``` + +If you were to send the following DNS request without the above rule, an example response would be: + +``` +$ dig @10.1.1.1 my-app.com + +;; QUESTION SECTION: +;my-app.com. IN A + +;; ANSWER SECTION: +my-app.com. 200 IN CNAME my-app.com.cdn.example.net. +my-app.com.cdn.example.net. 300 IN A 20.2.0.1 +my-app.com.cdn.example.net. 300 IN A 20.2.0.2 +``` + +If you were to send the same DNS request with the above rule set up, an example response would be: + +``` +$ dig @10.1.1.1 my-app.com + +;; QUESTION SECTION: +;my-app.com. IN A + +;; ANSWER SECTION: +my-app.com. 200 IN CNAME my-app.com.other.cdn.com. +my-app.com.other.cdn.com. 100 IN A 30.3.1.2 +``` +Note that the answer will contain a completely different set of answer records after rewriting the `CNAME` target. 
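+
+For a fixed target the same kind of rewrite can be written without a regular expression. A
+minimal sketch, reusing the hypothetical names from the example above:
+
+```
+rewrite cname exact my-app.com.cdn.example.net. my-app.com.other.cdn.com.
+```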
diff --git a/plugin/rewrite/class.go b/plugin/rewrite/class.go new file mode 100644 index 0000000..243a864 --- /dev/null +++ b/plugin/rewrite/class.go @@ -0,0 +1,44 @@ +package rewrite + +import ( + "context" + "fmt" + "strings" + + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +type classRule struct { + fromClass uint16 + toClass uint16 + NextAction string +} + +// newClassRule creates a class matching rule +func newClassRule(nextAction string, args ...string) (Rule, error) { + var from, to uint16 + var ok bool + if from, ok = dns.StringToClass[strings.ToUpper(args[0])]; !ok { + return nil, fmt.Errorf("invalid class %q", strings.ToUpper(args[0])) + } + if to, ok = dns.StringToClass[strings.ToUpper(args[1])]; !ok { + return nil, fmt.Errorf("invalid class %q", strings.ToUpper(args[1])) + } + return &classRule{from, to, nextAction}, nil +} + +// Rewrite rewrites the current request. +func (rule *classRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) { + if rule.fromClass > 0 && rule.toClass > 0 { + if state.Req.Question[0].Qclass == rule.fromClass { + state.Req.Question[0].Qclass = rule.toClass + return nil, RewriteDone + } + } + return nil, RewriteIgnored +} + +// Mode returns the processing mode. +func (rule *classRule) Mode() string { return rule.NextAction } diff --git a/plugin/rewrite/cname_target.go b/plugin/rewrite/cname_target.go new file mode 100644 index 0000000..d57bae3 --- /dev/null +++ b/plugin/rewrite/cname_target.go @@ -0,0 +1,152 @@ +package rewrite + +import ( + "context" + "fmt" + "regexp" + "strconv" + "strings" + + "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/plugin/pkg/upstream" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// UpstreamInt wraps the Upstream API for dependency injection during testing +type UpstreamInt interface { + Lookup(ctx context.Context, state request.Request, name string, typ uint16) (*dns.Msg, error) +} + +// cnameTargetRule is cname target rewrite rule. +type cnameTargetRule struct { + rewriteType string + paramFromTarget string + paramToTarget string + nextAction string + Upstream UpstreamInt // Upstream for looking up external names during the resolution process. 
+} + +// cnameTargetRuleWithReqState is cname target rewrite rule state +type cnameTargetRuleWithReqState struct { + rule cnameTargetRule + state request.Request + ctx context.Context +} + +func (r *cnameTargetRule) getFromAndToTarget(inputCName string) (from string, to string) { + switch r.rewriteType { + case ExactMatch: + return r.paramFromTarget, r.paramToTarget + case PrefixMatch: + if strings.HasPrefix(inputCName, r.paramFromTarget) { + return inputCName, r.paramToTarget + strings.TrimPrefix(inputCName, r.paramFromTarget) + } + case SuffixMatch: + if strings.HasSuffix(inputCName, r.paramFromTarget) { + return inputCName, strings.TrimSuffix(inputCName, r.paramFromTarget) + r.paramToTarget + } + case SubstringMatch: + if strings.Contains(inputCName, r.paramFromTarget) { + return inputCName, strings.Replace(inputCName, r.paramFromTarget, r.paramToTarget, -1) + } + case RegexMatch: + pattern := regexp.MustCompile(r.paramFromTarget) + regexGroups := pattern.FindStringSubmatch(inputCName) + if len(regexGroups) == 0 { + return "", "" + } + substitution := r.paramToTarget + for groupIndex, groupValue := range regexGroups { + groupIndexStr := "{" + strconv.Itoa(groupIndex) + "}" + substitution = strings.Replace(substitution, groupIndexStr, groupValue, -1) + } + return inputCName, substitution + } + return "", "" +} + +func (r *cnameTargetRuleWithReqState) RewriteResponse(res *dns.Msg, rr dns.RR) { + // logic to rewrite the cname target of dns response + switch rr.Header().Rrtype { + case dns.TypeCNAME: + // rename the target of the cname response + if cname, ok := rr.(*dns.CNAME); ok { + fromTarget, toTarget := r.rule.getFromAndToTarget(cname.Target) + if cname.Target == fromTarget { + // create upstream request with the new target with the same qtype + r.state.Req.Question[0].Name = toTarget + upRes, err := r.rule.Upstream.Lookup(r.ctx, r.state, toTarget, r.state.Req.Question[0].Qtype) + + if err != nil { + log.Errorf("Error upstream request %v", err) + } + + var newAnswer []dns.RR + // iterate over first upstram response + // add the cname record to the new answer + for _, rr := range res.Answer { + if cname, ok := rr.(*dns.CNAME); ok { + // change the target name in the response + cname.Target = toTarget + newAnswer = append(newAnswer, rr) + } + } + // iterate over upstream response received + for _, rr := range upRes.Answer { + if rr.Header().Name == toTarget { + newAnswer = append(newAnswer, rr) + } + } + res.Answer = newAnswer + } + } + } +} + +func newCNAMERule(nextAction string, args ...string) (Rule, error) { + var rewriteType string + var paramFromTarget, paramToTarget string + if len(args) == 3 { + rewriteType = (strings.ToLower(args[0])) + switch rewriteType { + case ExactMatch: + case PrefixMatch: + case SuffixMatch: + case SubstringMatch: + case RegexMatch: + default: + return nil, fmt.Errorf("unknown cname rewrite type: %s", rewriteType) + } + paramFromTarget, paramToTarget = strings.ToLower(args[1]), strings.ToLower(args[2]) + } else if len(args) == 2 { + rewriteType = ExactMatch + paramFromTarget, paramToTarget = strings.ToLower(args[0]), strings.ToLower(args[1]) + } else { + return nil, fmt.Errorf("too few (%d) arguments for a cname rule", len(args)) + } + rule := cnameTargetRule{ + rewriteType: rewriteType, + paramFromTarget: paramFromTarget, + paramToTarget: paramToTarget, + nextAction: nextAction, + Upstream: upstream.New(), + } + return &rule, nil +} + +// Rewrite rewrites the current request. 
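+// The question of the request is not rewritten here; instead a response rule is returned that
+// rewrites the CNAME target in the reply and resolves the new target via the configured upstream.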
+func (r *cnameTargetRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) { + if r != nil && len(r.rewriteType) > 0 && len(r.paramFromTarget) > 0 && len(r.paramToTarget) > 0 { + return ResponseRules{&cnameTargetRuleWithReqState{ + rule: *r, + state: state, + ctx: ctx, + }}, RewriteDone + } + return nil, RewriteIgnored +} + +// Mode returns the processing mode. +func (r *cnameTargetRule) Mode() string { return r.nextAction } diff --git a/plugin/rewrite/cname_target_test.go b/plugin/rewrite/cname_target_test.go new file mode 100644 index 0000000..9eee2b8 --- /dev/null +++ b/plugin/rewrite/cname_target_test.go @@ -0,0 +1,180 @@ +package rewrite + +import ( + "context" + "reflect" + "testing" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +type MockedUpstream struct{} + +func (u *MockedUpstream) Lookup(ctx context.Context, state request.Request, name string, typ uint16) (*dns.Msg, error) { + m := new(dns.Msg) + m.SetReply(state.Req) + m.Authoritative = true + switch state.Req.Question[0].Name { + case "xyz.example.com.": + switch state.Req.Question[0].Qtype { + case dns.TypeA: + m.Answer = []dns.RR{ + test.A("xyz.example.com. 3600 IN A 3.4.5.6"), + } + case dns.TypeAAAA: + m.Answer = []dns.RR{ + test.AAAA("xyz.example.com. 3600 IN AAAA 3a01:7e00::f03c:91ff:fe79:234c"), + } + } + return m, nil + case "bard.google.com.cdn.cloudflare.net.": + m.Answer = []dns.RR{ + test.A("bard.google.com.cdn.cloudflare.net. 1800 IN A 9.7.2.1"), + } + return m, nil + case "www.hosting.xyz.": + m.Answer = []dns.RR{ + test.A("www.hosting.xyz. 500 IN A 20.30.40.50"), + } + return m, nil + case "abcd.zzzz.www.pqrst.": + m.Answer = []dns.RR{ + test.A("abcd.zzzz.www.pqrst. 120 IN A 101.20.5.1"), + test.A("abcd.zzzz.www.pqrst. 120 IN A 101.20.5.2"), + } + return m, nil + case "orders.webapp.eu.org.": + m.Answer = []dns.RR{ + test.A("orders.webapp.eu.org. 120 IN A 20.0.0.9"), + } + return m, nil + } + return &dns.Msg{}, nil +} + +func TestCNameTargetRewrite(t *testing.T) { + rules := []Rule{} + ruleset := []struct { + args []string + expectedType reflect.Type + }{ + {[]string{"continue", "cname", "exact", "def.example.com.", "xyz.example.com."}, reflect.TypeOf(&cnameTargetRule{})}, + {[]string{"continue", "cname", "prefix", "chat.openai.com", "bard.google.com"}, reflect.TypeOf(&cnameTargetRule{})}, + {[]string{"continue", "cname", "suffix", "uvw.", "xyz."}, reflect.TypeOf(&cnameTargetRule{})}, + {[]string{"continue", "cname", "substring", "efgh", "zzzz.www"}, reflect.TypeOf(&cnameTargetRule{})}, + {[]string{"continue", "cname", "regex", `(.*)\.web\.(.*)\.site\.`, `{1}.webapp.{2}.org.`}, reflect.TypeOf(&cnameTargetRule{})}, + } + for i, r := range ruleset { + rule, err := newRule(r.args...) + if err != nil { + t.Fatalf("Rule %d: FAIL, %s: %s", i, r.args, err) + } + if reflect.TypeOf(rule) != r.expectedType { + t.Fatalf("Rule %d: FAIL, %s: rule type mismatch, expected %q, but got %q", i, r.args, r.expectedType, rule) + } + cnameTargetRule := rule.(*cnameTargetRule) + cnameTargetRule.Upstream = &MockedUpstream{} + rules = append(rules, rule) + } + doTestCNameTargetTests(rules, t) +} + +func doTestCNameTargetTests(rules []Rule, t *testing.T) { + tests := []struct { + from string + fromType uint16 + answer []dns.RR + expectedAnswer []dns.RR + }{ + {"abc.example.com", dns.TypeA, + []dns.RR{ + test.CNAME("abc.example.com. 
5 IN CNAME def.example.com."), + test.A("def.example.com. 5 IN A 1.2.3.4"), + }, + []dns.RR{ + test.CNAME("abc.example.com. 5 IN CNAME xyz.example.com."), + test.A("xyz.example.com. 3600 IN A 3.4.5.6"), + }, + }, + {"abc.example.com", dns.TypeAAAA, + []dns.RR{ + test.CNAME("abc.example.com. 5 IN CNAME def.example.com."), + test.AAAA("def.example.com. 5 IN AAAA 2a01:7e00::f03c:91ff:fe79:234c"), + }, + []dns.RR{ + test.CNAME("abc.example.com. 5 IN CNAME xyz.example.com."), + test.AAAA("xyz.example.com. 3600 IN AAAA 3a01:7e00::f03c:91ff:fe79:234c"), + }, + }, + {"chat.openai.com", dns.TypeA, + []dns.RR{ + test.CNAME("chat.openai.com. 20 IN CNAME chat.openai.com.cdn.cloudflare.net."), + test.A("chat.openai.com.cdn.cloudflare.net. 30 IN A 23.2.1.2"), + test.A("chat.openai.com.cdn.cloudflare.net. 30 IN A 24.6.0.8"), + }, + []dns.RR{ + test.CNAME("chat.openai.com. 20 IN CNAME bard.google.com.cdn.cloudflare.net."), + test.A("bard.google.com.cdn.cloudflare.net. 1800 IN A 9.7.2.1"), + }, + }, + {"coredns.io", dns.TypeA, + []dns.RR{ + test.CNAME("coredns.io. 100 IN CNAME www.hosting.uvw."), + test.A("www.hosting.uvw. 200 IN A 7.2.3.4"), + }, + []dns.RR{ + test.CNAME("coredns.io. 100 IN CNAME www.hosting.xyz."), + test.A("www.hosting.xyz. 500 IN A 20.30.40.50"), + }, + }, + {"core.dns.rocks", dns.TypeA, + []dns.RR{ + test.CNAME("core.dns.rocks. 200 IN CNAME abcd.efgh.pqrst."), + test.A("abcd.efgh.pqrst. 100 IN A 200.30.45.67"), + }, + []dns.RR{ + test.CNAME("core.dns.rocks. 200 IN CNAME abcd.zzzz.www.pqrst."), + test.A("abcd.zzzz.www.pqrst. 120 IN A 101.20.5.1"), + test.A("abcd.zzzz.www.pqrst. 120 IN A 101.20.5.2"), + }, + }, + {"order.service.eu", dns.TypeA, + []dns.RR{ + test.CNAME("order.service.eu. 200 IN CNAME orders.web.eu.site."), + test.A("orders.web.eu.site. 50 IN A 10.10.15.1"), + }, + []dns.RR{ + test.CNAME("order.service.eu. 200 IN CNAME orders.webapp.eu.org."), + test.A("orders.webapp.eu.org. 120 IN A 20.0.0.9"), + }, + }, + } + ctx := context.TODO() + for i, tc := range tests { + m := new(dns.Msg) + m.SetQuestion(tc.from, tc.fromType) + m.Question[0].Qclass = dns.ClassINET + m.Answer = tc.answer + rw := Rewrite{ + Next: plugin.HandlerFunc(msgPrinter), + Rules: rules, + } + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + rw.ServeDNS(ctx, rec, m) + resp := rec.Msg + if len(resp.Answer) == 0 { + t.Errorf("Test %d: FAIL %s (%d) Expected valid response but received %q", i, tc.from, tc.fromType, resp) + continue + } + if !reflect.DeepEqual(resp.Answer, tc.expectedAnswer) { + t.Errorf("Test %d: FAIL %s (%d) Actual are expected answer does not match, actual: %v, expected: %v", + i, tc.from, tc.fromType, resp.Answer, tc.expectedAnswer) + continue + } + } +} diff --git a/plugin/rewrite/edns0.go b/plugin/rewrite/edns0.go new file mode 100644 index 0000000..85146c7 --- /dev/null +++ b/plugin/rewrite/edns0.go @@ -0,0 +1,371 @@ +// Package rewrite is a plugin for rewriting requests internally to something different. +package rewrite + +import ( + "context" + "encoding/hex" + "fmt" + "net" + "strconv" + "strings" + + "github.com/coredns/coredns/plugin/metadata" + "github.com/coredns/coredns/plugin/pkg/edns" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// edns0LocalRule is a rewrite rule for EDNS0_LOCAL options. +type edns0LocalRule struct { + mode string + action string + code uint16 + data []byte +} + +// edns0VariableRule is a rewrite rule for EDNS0_LOCAL options with variable. 
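+// The variable (for example {client_ip}, or a metadata label in curly brackets) is resolved per
+// request and its value becomes the option data.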
+type edns0VariableRule struct { + mode string + action string + code uint16 + variable string +} + +// ends0NsidRule is a rewrite rule for EDNS0_NSID options. +type edns0NsidRule struct { + mode string + action string +} + +// setupEdns0Opt will retrieve the EDNS0 OPT or create it if it does not exist. +func setupEdns0Opt(r *dns.Msg) *dns.OPT { + o := r.IsEdns0() + if o == nil { + r.SetEdns0(4096, false) + o = r.IsEdns0() + } + return o +} + +// Rewrite will alter the request EDNS0 NSID option +func (rule *edns0NsidRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) { + o := setupEdns0Opt(state.Req) + + for _, s := range o.Option { + if e, ok := s.(*dns.EDNS0_NSID); ok { + if rule.action == Replace || rule.action == Set { + e.Nsid = "" // make sure it is empty for request + return nil, RewriteDone + } + } + } + + // add option if not found + if rule.action == Append || rule.action == Set { + o.Option = append(o.Option, &dns.EDNS0_NSID{Code: dns.EDNS0NSID, Nsid: ""}) + return nil, RewriteDone + } + + return nil, RewriteIgnored +} + +// Mode returns the processing mode. +func (rule *edns0NsidRule) Mode() string { return rule.mode } + +// Rewrite will alter the request EDNS0 local options. +func (rule *edns0LocalRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) { + o := setupEdns0Opt(state.Req) + + for _, s := range o.Option { + if e, ok := s.(*dns.EDNS0_LOCAL); ok { + if rule.code == e.Code { + if rule.action == Replace || rule.action == Set { + e.Data = rule.data + return nil, RewriteDone + } + } + } + } + + // add option if not found + if rule.action == Append || rule.action == Set { + o.Option = append(o.Option, &dns.EDNS0_LOCAL{Code: rule.code, Data: rule.data}) + return nil, RewriteDone + } + + return nil, RewriteIgnored +} + +// Mode returns the processing mode. +func (rule *edns0LocalRule) Mode() string { return rule.mode } + +// newEdns0Rule creates an EDNS0 rule of the appropriate type based on the args +func newEdns0Rule(mode string, args ...string) (Rule, error) { + if len(args) < 2 { + return nil, fmt.Errorf("too few arguments for an EDNS0 rule") + } + + ruleType := strings.ToLower(args[0]) + action := strings.ToLower(args[1]) + switch action { + case Append: + case Replace: + case Set: + default: + return nil, fmt.Errorf("invalid action: %q", action) + } + + switch ruleType { + case "local": + if len(args) != 4 { + return nil, fmt.Errorf("EDNS0 local rules require exactly three args") + } + // Check for variable option. + if strings.HasPrefix(args[3], "{") && strings.HasSuffix(args[3], "}") { + return newEdns0VariableRule(mode, action, args[2], args[3]) + } + return newEdns0LocalRule(mode, action, args[2], args[3]) + case "nsid": + if len(args) != 2 { + return nil, fmt.Errorf("EDNS0 NSID rules do not accept args") + } + return &edns0NsidRule{mode: mode, action: action}, nil + case "subnet": + if len(args) != 4 { + return nil, fmt.Errorf("EDNS0 subnet rules require exactly three args") + } + return newEdns0SubnetRule(mode, action, args[2], args[3]) + default: + return nil, fmt.Errorf("invalid rule type %q", ruleType) + } +} + +func newEdns0LocalRule(mode, action, code, data string) (*edns0LocalRule, error) { + c, err := strconv.ParseUint(code, 0, 16) + if err != nil { + return nil, err + } + + decoded := []byte(data) + if strings.HasPrefix(data, "0x") { + decoded, err = hex.DecodeString(data[2:]) + if err != nil { + return nil, err + } + } + + // Add this code to the ones the server supports. 
+ edns.SetSupportedOption(uint16(c)) + + return &edns0LocalRule{mode: mode, action: action, code: uint16(c), data: decoded}, nil +} + +// newEdns0VariableRule creates an EDNS0 rule that handles variable substitution +func newEdns0VariableRule(mode, action, code, variable string) (*edns0VariableRule, error) { + c, err := strconv.ParseUint(code, 0, 16) + if err != nil { + return nil, err + } + //Validate + if !isValidVariable(variable) { + return nil, fmt.Errorf("unsupported variable name %q", variable) + } + + // Add this code to the ones the server supports. + edns.SetSupportedOption(uint16(c)) + + return &edns0VariableRule{mode: mode, action: action, code: uint16(c), variable: variable}, nil +} + +// ruleData returns the data specified by the variable. +func (rule *edns0VariableRule) ruleData(ctx context.Context, state request.Request) ([]byte, error) { + switch rule.variable { + case queryName: + return []byte(state.QName()), nil + + case queryType: + return uint16ToWire(state.QType()), nil + + case clientIP: + return ipToWire(state.Family(), state.IP()) + + case serverIP: + return ipToWire(state.Family(), state.LocalIP()) + + case clientPort: + return portToWire(state.Port()) + + case serverPort: + return portToWire(state.LocalPort()) + + case protocol: + return []byte(state.Proto()), nil + } + + fetcher := metadata.ValueFunc(ctx, rule.variable[1:len(rule.variable)-1]) + if fetcher != nil { + value := fetcher() + if len(value) > 0 { + return []byte(value), nil + } + } + + return nil, fmt.Errorf("unable to extract data for variable %s", rule.variable) +} + +// Rewrite will alter the request EDNS0 local options with specified variables. +func (rule *edns0VariableRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) { + data, err := rule.ruleData(ctx, state) + if err != nil || data == nil { + return nil, RewriteIgnored + } + + o := setupEdns0Opt(state.Req) + for _, s := range o.Option { + if e, ok := s.(*dns.EDNS0_LOCAL); ok { + if rule.code == e.Code { + if rule.action == Replace || rule.action == Set { + e.Data = data + return nil, RewriteDone + } + return nil, RewriteIgnored + } + } + } + + // add option if not found + if rule.action == Append || rule.action == Set { + o.Option = append(o.Option, &dns.EDNS0_LOCAL{Code: rule.code, Data: data}) + return nil, RewriteDone + } + + return nil, RewriteIgnored +} + +// Mode returns the processing mode. 
+func (rule *edns0VariableRule) Mode() string { return rule.mode } + +func isValidVariable(variable string) bool { + switch variable { + case + queryName, + queryType, + clientIP, + clientPort, + protocol, + serverIP, + serverPort: + return true + } + // we cannot validate the labels of metadata - but we can verify it has the syntax of a label + if strings.HasPrefix(variable, "{") && strings.HasSuffix(variable, "}") && metadata.IsLabel(variable[1:len(variable)-1]) { + return true + } + return false +} + +// ends0SubnetRule is a rewrite rule for EDNS0 subnet options +type edns0SubnetRule struct { + mode string + v4BitMaskLen uint8 + v6BitMaskLen uint8 + action string +} + +func newEdns0SubnetRule(mode, action, v4BitMaskLen, v6BitMaskLen string) (*edns0SubnetRule, error) { + v4Len, err := strconv.ParseUint(v4BitMaskLen, 0, 16) + if err != nil { + return nil, err + } + // validate V4 length + if v4Len > net.IPv4len*8 { + return nil, fmt.Errorf("invalid IPv4 bit mask length %d", v4Len) + } + + v6Len, err := strconv.ParseUint(v6BitMaskLen, 0, 16) + if err != nil { + return nil, err + } + // validate V6 length + if v6Len > net.IPv6len*8 { + return nil, fmt.Errorf("invalid IPv6 bit mask length %d", v6Len) + } + + return &edns0SubnetRule{mode: mode, action: action, + v4BitMaskLen: uint8(v4Len), v6BitMaskLen: uint8(v6Len)}, nil +} + +// fillEcsData sets the subnet data into the ecs option +func (rule *edns0SubnetRule) fillEcsData(state request.Request, ecs *dns.EDNS0_SUBNET) error { + family := state.Family() + if (family != 1) && (family != 2) { + return fmt.Errorf("unable to fill data for EDNS0 subnet due to invalid IP family") + } + + ecs.Family = uint16(family) + ecs.SourceScope = 0 + + ipAddr := state.IP() + switch family { + case 1: + ipv4Mask := net.CIDRMask(int(rule.v4BitMaskLen), 32) + ipv4Addr := net.ParseIP(ipAddr) + ecs.SourceNetmask = rule.v4BitMaskLen + ecs.Address = ipv4Addr.Mask(ipv4Mask).To4() + case 2: + ipv6Mask := net.CIDRMask(int(rule.v6BitMaskLen), 128) + ipv6Addr := net.ParseIP(ipAddr) + ecs.SourceNetmask = rule.v6BitMaskLen + ecs.Address = ipv6Addr.Mask(ipv6Mask).To16() + } + return nil +} + +// Rewrite will alter the request EDNS0 subnet option. +func (rule *edns0SubnetRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) { + o := setupEdns0Opt(state.Req) + + for _, s := range o.Option { + if e, ok := s.(*dns.EDNS0_SUBNET); ok { + if rule.action == Replace || rule.action == Set { + if rule.fillEcsData(state, e) == nil { + return nil, RewriteDone + } + } + return nil, RewriteIgnored + } + } + + // add option if not found + if rule.action == Append || rule.action == Set { + opt := &dns.EDNS0_SUBNET{Code: dns.EDNS0SUBNET} + if rule.fillEcsData(state, opt) == nil { + o.Option = append(o.Option, opt) + return nil, RewriteDone + } + } + + return nil, RewriteIgnored +} + +// Mode returns the processing mode +func (rule *edns0SubnetRule) Mode() string { return rule.mode } + +// These are all defined actions. 
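+// replace modifies a matching option, append adds an option only if no matching one exists, and
+// set does either, depending on whether a match is found.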
+const ( + Replace = "replace" + Set = "set" + Append = "append" +) + +// Supported local EDNS0 variables +const ( + queryName = "{qname}" + queryType = "{qtype}" + clientIP = "{client_ip}" + clientPort = "{client_port}" + protocol = "{protocol}" + serverIP = "{server_ip}" + serverPort = "{server_port}" +) diff --git a/plugin/rewrite/fuzz.go b/plugin/rewrite/fuzz.go new file mode 100644 index 0000000..8e44ebb --- /dev/null +++ b/plugin/rewrite/fuzz.go @@ -0,0 +1,20 @@ +//go:build gofuzz + +package rewrite + +import ( + "github.com/coredns/caddy" + "github.com/coredns/coredns/plugin/pkg/fuzz" +) + +// Fuzz fuzzes rewrite. +func Fuzz(data []byte) int { + c := caddy.NewTestController("dns", "rewrite edns0 subnet set 24 56") + rules, err := rewriteParse(c) + if err != nil { + return 0 + } + r := Rewrite{Rules: rules} + + return fuzz.Do(r, data) +} diff --git a/plugin/rewrite/log_test.go b/plugin/rewrite/log_test.go new file mode 100644 index 0000000..6ce3627 --- /dev/null +++ b/plugin/rewrite/log_test.go @@ -0,0 +1,5 @@ +package rewrite + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/rewrite/name.go b/plugin/rewrite/name.go new file mode 100644 index 0000000..d3da9c2 --- /dev/null +++ b/plugin/rewrite/name.go @@ -0,0 +1,449 @@ +package rewrite + +import ( + "context" + "fmt" + "regexp" + "strconv" + "strings" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// stringRewriter rewrites a string +type stringRewriter interface { + rewriteString(src string) string +} + +// regexStringRewriter can be used to rewrite strings by regex pattern. +// it contains all the information required to detect and execute a rewrite +// on a string. +type regexStringRewriter struct { + pattern *regexp.Regexp + replacement string +} + +var _ stringRewriter = ®exStringRewriter{} + +func newStringRewriter(pattern *regexp.Regexp, replacement string) stringRewriter { + return ®exStringRewriter{pattern, replacement} +} + +func (r *regexStringRewriter) rewriteString(src string) string { + regexGroups := r.pattern.FindStringSubmatch(src) + if len(regexGroups) == 0 { + return src + } + s := r.replacement + for groupIndex, groupValue := range regexGroups { + groupIndexStr := "{" + strconv.Itoa(groupIndex) + "}" + s = strings.Replace(s, groupIndexStr, groupValue, -1) + } + return s +} + +// remapStringRewriter maps a dedicated string to another string +// it also maps a the domain of a sub domain. 
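+// For example, with orig "example.org." and replacement "example.com.", both "example.org."
+// itself and a subdomain such as "a.example.org." are rewritten to the corresponding
+// "example.com." name.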
+type remapStringRewriter struct { + orig string + replacement string +} + +var _ stringRewriter = &remapStringRewriter{} + +func newRemapStringRewriter(orig, replacement string) stringRewriter { + return &remapStringRewriter{orig, replacement} +} + +func (r *remapStringRewriter) rewriteString(src string) string { + if src == r.orig { + return r.replacement + } + if strings.HasSuffix(src, "."+r.orig) { + return src[0:len(src)-len(r.orig)] + r.replacement + } + return src +} + +// suffixStringRewriter maps a dedicated suffix string to another string +type suffixStringRewriter struct { + suffix string + replacement string +} + +var _ stringRewriter = &suffixStringRewriter{} + +func newSuffixStringRewriter(orig, replacement string) stringRewriter { + return &suffixStringRewriter{orig, replacement} +} + +func (r *suffixStringRewriter) rewriteString(src string) string { + if strings.HasSuffix(src, r.suffix) { + return strings.TrimSuffix(src, r.suffix) + r.replacement + } + return src +} + +// nameRewriterResponseRule maps a record name according to a stringRewriter. +type nameRewriterResponseRule struct { + stringRewriter +} + +func (r *nameRewriterResponseRule) RewriteResponse(res *dns.Msg, rr dns.RR) { + rr.Header().Name = r.rewriteString(rr.Header().Name) +} + +// valueRewriterResponseRule maps a record value according to a stringRewriter. +type valueRewriterResponseRule struct { + stringRewriter +} + +func (r *valueRewriterResponseRule) RewriteResponse(res *dns.Msg, rr dns.RR) { + value := getRecordValueForRewrite(rr) + if value != "" { + new := r.rewriteString(value) + if new != value { + setRewrittenRecordValue(rr, new) + } + } +} + +const ( + // ExactMatch matches only on exact match of the name in the question section of a request + ExactMatch = "exact" + // PrefixMatch matches when the name begins with the matching string + PrefixMatch = "prefix" + // SuffixMatch matches when the name ends with the matching string + SuffixMatch = "suffix" + // SubstringMatch matches on partial match of the name in the question section of a request + SubstringMatch = "substring" + // RegexMatch matches when the name in the question section of a request matches a regular expression + RegexMatch = "regex" + + // AnswerMatch matches an answer rewrite + AnswerMatch = "answer" + // AutoMatch matches the auto name answer rewrite + AutoMatch = "auto" + // NameMatch matches the name answer rewrite + NameMatch = "name" + // ValueMatch matches the value answer rewrite + ValueMatch = "value" +) + +type nameRuleBase struct { + nextAction string + auto bool + replacement string + static ResponseRules +} + +func newNameRuleBase(nextAction string, auto bool, replacement string, staticResponses ResponseRules) nameRuleBase { + return nameRuleBase{ + nextAction: nextAction, + auto: auto, + replacement: replacement, + static: staticResponses, + } +} + +// responseRuleFor create for auto mode dynamically response rewriters for name and value +// reverting the mapping done by the name rewrite rule, which can be found in the state. 
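+// By the time this is called the question has already been rewritten, so the generated remap
+// maps the new question name back to the original name still held in the request state.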
+func (rule *nameRuleBase) responseRuleFor(state request.Request) (ResponseRules, Result) { + if !rule.auto { + return rule.static, RewriteDone + } + + rewriter := newRemapStringRewriter(state.Req.Question[0].Name, state.Name()) + rules := ResponseRules{ + &nameRewriterResponseRule{rewriter}, + &valueRewriterResponseRule{rewriter}, + } + return append(rules, rule.static...), RewriteDone +} + +// Mode returns the processing nextAction +func (rule *nameRuleBase) Mode() string { return rule.nextAction } + +// exactNameRule rewrites the current request based upon exact match of the name +// in the question section of the request. +type exactNameRule struct { + nameRuleBase + from string +} + +func newExactNameRule(nextAction string, orig, replacement string, answers ResponseRules) Rule { + return &exactNameRule{ + newNameRuleBase(nextAction, true, replacement, answers), + orig, + } +} + +func (rule *exactNameRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) { + if rule.from == state.Name() { + state.Req.Question[0].Name = rule.replacement + return rule.responseRuleFor(state) + } + return nil, RewriteIgnored +} + +// prefixNameRule rewrites the current request when the name begins with the matching string. +type prefixNameRule struct { + nameRuleBase + prefix string +} + +func newPrefixNameRule(nextAction string, auto bool, prefix, replacement string, answers ResponseRules) Rule { + return &prefixNameRule{ + newNameRuleBase(nextAction, auto, replacement, answers), + prefix, + } +} + +func (rule *prefixNameRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) { + if strings.HasPrefix(state.Name(), rule.prefix) { + state.Req.Question[0].Name = rule.replacement + strings.TrimPrefix(state.Name(), rule.prefix) + return rule.responseRuleFor(state) + } + return nil, RewriteIgnored +} + +// suffixNameRule rewrites the current request when the name ends with the matching string. +type suffixNameRule struct { + nameRuleBase + suffix string +} + +func newSuffixNameRule(nextAction string, auto bool, suffix, replacement string, answers ResponseRules) Rule { + var rules ResponseRules + if auto { + // for a suffix rewriter better standard response rewrites can be done + // just by using the original suffix/replacement in the opposite order + rewriter := newSuffixStringRewriter(replacement, suffix) + rules = ResponseRules{ + &nameRewriterResponseRule{rewriter}, + &valueRewriterResponseRule{rewriter}, + } + } + return &suffixNameRule{ + newNameRuleBase(nextAction, false, replacement, append(rules, answers...)), + suffix, + } +} + +func (rule *suffixNameRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) { + if strings.HasSuffix(state.Name(), rule.suffix) { + state.Req.Question[0].Name = strings.TrimSuffix(state.Name(), rule.suffix) + rule.replacement + return rule.responseRuleFor(state) + } + return nil, RewriteIgnored +} + +// substringNameRule rewrites the current request based upon partial match of the +// name in the question section of the request. 
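+// For example, a rule with substring "service" and replacement "svc" would
+// rewrite "my.service.example.org." to "my.svc.example.org." (illustrative
+// placeholder names); every occurrence of the substring is replaced.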
+type substringNameRule struct { + nameRuleBase + substring string +} + +func newSubstringNameRule(nextAction string, auto bool, substring, replacement string, answers ResponseRules) Rule { + return &substringNameRule{ + newNameRuleBase(nextAction, auto, replacement, answers), + substring, + } +} + +func (rule *substringNameRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) { + if strings.Contains(state.Name(), rule.substring) { + state.Req.Question[0].Name = strings.Replace(state.Name(), rule.substring, rule.replacement, -1) + return rule.responseRuleFor(state) + } + return nil, RewriteIgnored +} + +// regexNameRule rewrites the current request when the name in the question +// section of the request matches a regular expression. +type regexNameRule struct { + nameRuleBase + pattern *regexp.Regexp +} + +func newRegexNameRule(nextAction string, auto bool, pattern *regexp.Regexp, replacement string, answers ResponseRules) Rule { + return ®exNameRule{ + newNameRuleBase(nextAction, auto, replacement, answers), + pattern, + } +} + +func (rule *regexNameRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) { + regexGroups := rule.pattern.FindStringSubmatch(state.Name()) + if len(regexGroups) == 0 { + return nil, RewriteIgnored + } + s := rule.replacement + for groupIndex, groupValue := range regexGroups { + groupIndexStr := "{" + strconv.Itoa(groupIndex) + "}" + s = strings.Replace(s, groupIndexStr, groupValue, -1) + } + state.Req.Question[0].Name = s + return rule.responseRuleFor(state) +} + +// newNameRule creates a name matching rule based on exact, partial, or regex match +func newNameRule(nextAction string, args ...string) (Rule, error) { + var matchType, rewriteQuestionFrom, rewriteQuestionTo string + if len(args) < 2 { + return nil, fmt.Errorf("too few arguments for a name rule") + } + if len(args) == 2 { + matchType = ExactMatch + rewriteQuestionFrom = plugin.Name(args[0]).Normalize() + rewriteQuestionTo = plugin.Name(args[1]).Normalize() + } + if len(args) >= 3 { + matchType = strings.ToLower(args[0]) + if matchType == RegexMatch { + rewriteQuestionFrom = args[1] + rewriteQuestionTo = args[2] + } else { + rewriteQuestionFrom = plugin.Name(args[1]).Normalize() + rewriteQuestionTo = plugin.Name(args[2]).Normalize() + } + } + if matchType == ExactMatch || matchType == SuffixMatch { + if !hasClosingDot(rewriteQuestionFrom) { + rewriteQuestionFrom = rewriteQuestionFrom + "." + } + if !hasClosingDot(rewriteQuestionTo) { + rewriteQuestionTo = rewriteQuestionTo + "." 
+ } + } + + var err error + var answers ResponseRules + auto := false + if len(args) > 3 { + auto, answers, err = parseAnswerRules(matchType, args[3:]) + if err != nil { + return nil, err + } + } + + switch matchType { + case ExactMatch: + if _, err := isValidRegexPattern(rewriteQuestionTo, rewriteQuestionFrom); err != nil { + return nil, err + } + return newExactNameRule(nextAction, rewriteQuestionFrom, rewriteQuestionTo, answers), nil + case PrefixMatch: + return newPrefixNameRule(nextAction, auto, rewriteQuestionFrom, rewriteQuestionTo, answers), nil + case SuffixMatch: + return newSuffixNameRule(nextAction, auto, rewriteQuestionFrom, rewriteQuestionTo, answers), nil + case SubstringMatch: + return newSubstringNameRule(nextAction, auto, rewriteQuestionFrom, rewriteQuestionTo, answers), nil + case RegexMatch: + rewriteQuestionFromPattern, err := isValidRegexPattern(rewriteQuestionFrom, rewriteQuestionTo) + if err != nil { + return nil, err + } + rewriteQuestionTo := plugin.Name(args[2]).Normalize() + return newRegexNameRule(nextAction, auto, rewriteQuestionFromPattern, rewriteQuestionTo, answers), nil + default: + return nil, fmt.Errorf("name rule supports only exact, prefix, suffix, substring, and regex name matching, received: %s", matchType) + } +} + +func parseAnswerRules(name string, args []string) (auto bool, rules ResponseRules, err error) { + auto = false + arg := 0 + nameRules := 0 + last := "" + if len(args) < 2 { + return false, nil, fmt.Errorf("invalid arguments for %s rule", name) + } + for arg < len(args) { + if last == "" && args[arg] != AnswerMatch { + if last == "" { + return false, nil, fmt.Errorf("exceeded the number of arguments for a non-answer rule argument for %s rule", name) + } + return false, nil, fmt.Errorf("exceeded the number of arguments for %s answer rule for %s rule", last, name) + } + if args[arg] == AnswerMatch { + arg++ + } + if len(args)-arg == 0 { + return false, nil, fmt.Errorf("type missing for answer rule for %s rule", name) + } + last = args[arg] + arg++ + switch last { + case AutoMatch: + auto = true + continue + case NameMatch: + if len(args)-arg < 2 { + return false, nil, fmt.Errorf("%s answer rule for %s rule: 2 arguments required", last, name) + } + rewriteAnswerFrom := args[arg] + rewriteAnswerTo := args[arg+1] + rewriteAnswerFromPattern, err := isValidRegexPattern(rewriteAnswerFrom, rewriteAnswerTo) + rewriteAnswerTo = plugin.Name(rewriteAnswerTo).Normalize() + if err != nil { + return false, nil, fmt.Errorf("%s answer rule for %s rule: %s", last, name, err) + } + rules = append(rules, &nameRewriterResponseRule{newStringRewriter(rewriteAnswerFromPattern, rewriteAnswerTo)}) + arg += 2 + nameRules++ + case ValueMatch: + if len(args)-arg < 2 { + return false, nil, fmt.Errorf("%s answer rule for %s rule: 2 arguments required", last, name) + } + rewriteAnswerFrom := args[arg] + rewriteAnswerTo := args[arg+1] + rewriteAnswerFromPattern, err := isValidRegexPattern(rewriteAnswerFrom, rewriteAnswerTo) + rewriteAnswerTo = plugin.Name(rewriteAnswerTo).Normalize() + if err != nil { + return false, nil, fmt.Errorf("%s answer rule for %s rule: %s", last, name, err) + } + rules = append(rules, &valueRewriterResponseRule{newStringRewriter(rewriteAnswerFromPattern, rewriteAnswerTo)}) + arg += 2 + default: + return false, nil, fmt.Errorf("invalid type %q for answer rule for %s rule", last, name) + } + } + + if auto && nameRules > 0 { + return false, nil, fmt.Errorf("auto name answer rule cannot be combined with explicit name anwer rules") + } + return +} + 
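+// Answer argument forms accepted by parseAnswerRules above, shown as
+// illustrative Corefile fragments (the regex and domain names are
+// placeholders, not defaults):
+//
+//	answer auto
+//	answer name REGEX REPLACEMENT
+//	answer value REGEX REPLACEMENT
+//
+// "auto" derives the reverse mapping from the name rule itself and cannot be
+// combined with explicit "name" answer rules; "name" and "value" rewrite
+// answer record names and record values with an explicit regex.
+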
+// hasClosingDot returns true if s has a closing dot at the end. +func hasClosingDot(s string) bool { + return strings.HasSuffix(s, ".") +} + +// getSubExprUsage returns the number of subexpressions used in s. +func getSubExprUsage(s string) int { + subExprUsage := 0 + for i := 0; i <= 100; i++ { + if strings.Contains(s, "{"+strconv.Itoa(i)+"}") { + subExprUsage++ + } + } + return subExprUsage +} + +// isValidRegexPattern returns a regular expression for pattern matching or errors, if any. +func isValidRegexPattern(rewriteFrom, rewriteTo string) (*regexp.Regexp, error) { + rewriteFromPattern, err := regexp.Compile(rewriteFrom) + if err != nil { + return nil, fmt.Errorf("invalid regex matching pattern: %s", rewriteFrom) + } + if getSubExprUsage(rewriteTo) > rewriteFromPattern.NumSubexp() { + return nil, fmt.Errorf("the rewrite regex pattern (%s) uses more subexpressions than its corresponding matching regex pattern (%s)", rewriteTo, rewriteFrom) + } + return rewriteFromPattern, nil +} diff --git a/plugin/rewrite/name_test.go b/plugin/rewrite/name_test.go new file mode 100644 index 0000000..2dbf1d1 --- /dev/null +++ b/plugin/rewrite/name_test.go @@ -0,0 +1,376 @@ +package rewrite + +import ( + "context" + "strings" + "testing" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestRewriteIllegalName(t *testing.T) { + r, _ := newNameRule("stop", "example.org.", "example..org.") + + rw := Rewrite{ + Next: plugin.HandlerFunc(msgPrinter), + Rules: []Rule{r}, + RevertPolicy: NoRevertPolicy(), + } + + ctx := context.TODO() + m := new(dns.Msg) + m.SetQuestion("example.org.", dns.TypeA) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err := rw.ServeDNS(ctx, rec, m) + if !strings.Contains(err.Error(), "invalid name") { + t.Errorf("Expected invalid name, got %s", err.Error()) + } +} + +func TestRewriteNamePrefixSuffix(t *testing.T) { + ctx, close := context.WithCancel(context.TODO()) + defer close() + + tests := []struct { + next string + args []string + question string + expected string + }{ + {"stop", []string{"prefix", "foo", "bar"}, "foo.example.com.", "bar.example.com."}, + {"stop", []string{"prefix", "foo.", "bar."}, "foo.example.com.", "bar.example.com."}, + {"stop", []string{"suffix", "com", "org"}, "foo.example.com.", "foo.example.org."}, + {"stop", []string{"suffix", ".com", ".org"}, "foo.example.com.", "foo.example.org."}, + } + for _, tc := range tests { + r, err := newNameRule(tc.next, tc.args...) 
+ if err != nil { + t.Fatalf("Expected no error, got %s", err) + } + + rw := Rewrite{ + Next: plugin.HandlerFunc(msgPrinter), + Rules: []Rule{r}, + RevertPolicy: NoRevertPolicy(), + } + + m := new(dns.Msg) + m.SetQuestion(tc.question, dns.TypeA) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err = rw.ServeDNS(ctx, rec, m) + if err != nil { + t.Fatalf("Expected no error, got %s", err) + } + actual := rec.Msg.Question[0].Name + if actual != tc.expected { + t.Fatalf("Expected rewrite to %v, got %v", tc.expected, actual) + } + } +} + +func TestRewriteNameNoRewrite(t *testing.T) { + ctx, close := context.WithCancel(context.TODO()) + defer close() + + tests := []struct { + next string + args []string + question string + expected string + }{ + {"stop", []string{"prefix", "foo", "bar"}, "coredns.foo.", "coredns.foo."}, + {"stop", []string{"prefix", "foo", "bar."}, "coredns.foo.", "coredns.foo."}, + {"stop", []string{"suffix", "com", "org"}, "com.coredns.", "com.coredns."}, + {"stop", []string{"suffix", "com", "org."}, "com.coredns.", "com.coredns."}, + {"stop", []string{"substring", "service", "svc"}, "com.coredns.", "com.coredns."}, + } + for i, tc := range tests { + r, err := newNameRule(tc.next, tc.args...) + if err != nil { + t.Fatalf("Test %d: Expected no error, got %s", i, err) + } + + rw := Rewrite{ + Next: plugin.HandlerFunc(msgPrinter), + Rules: []Rule{r}, + } + + m := new(dns.Msg) + m.SetQuestion(tc.question, dns.TypeA) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err = rw.ServeDNS(ctx, rec, m) + if err != nil { + t.Fatalf("Test %d: Expected no error, got %s", i, err) + } + actual := rec.Msg.Answer[0].Header().Name + if actual != tc.expected { + t.Fatalf("Test %d: Expected answer rewrite to %v, got %v", i, tc.expected, actual) + } + } +} + +func TestRewriteNamePrefixSuffixNoAutoAnswer(t *testing.T) { + ctx, close := context.WithCancel(context.TODO()) + defer close() + + tests := []struct { + next string + args []string + question string + expected string + }{ + {"stop", []string{"prefix", "foo", "bar"}, "foo.example.com.", "bar.example.com."}, + {"stop", []string{"prefix", "foo.", "bar."}, "foo.example.com.", "bar.example.com."}, + {"stop", []string{"suffix", "com", "org"}, "foo.example.com.", "foo.example.org."}, + {"stop", []string{"suffix", ".com", ".org"}, "foo.example.com.", "foo.example.org."}, + {"stop", []string{"suffix", ".ingress.coredns.rocks", "nginx.coredns.rocks"}, "coredns.ingress.coredns.rocks.", "corednsnginx.coredns.rocks."}, + } + for i, tc := range tests { + r, err := newNameRule(tc.next, tc.args...) 
+ if err != nil { + t.Fatalf("Test %d: Expected no error, got %s", i, err) + } + + rw := Rewrite{ + Next: plugin.HandlerFunc(msgPrinter), + Rules: []Rule{r}, + } + + m := new(dns.Msg) + m.SetQuestion(tc.question, dns.TypeA) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err = rw.ServeDNS(ctx, rec, m) + if err != nil { + t.Fatalf("Test %d: Expected no error, got %s", i, err) + } + actual := rec.Msg.Answer[0].Header().Name + if actual != tc.expected { + t.Fatalf("Test %d: Expected answer rewrite to %v, got %v", i, tc.expected, actual) + } + } +} + +func TestRewriteNamePrefixSuffixAutoAnswer(t *testing.T) { + ctx, close := context.WithCancel(context.TODO()) + defer close() + + tests := []struct { + next string + args []string + question string + rewrite string + expected string + }{ + {"stop", []string{"prefix", "foo", "bar", "answer", "auto"}, "foo.example.com.", "bar.example.com.", "foo.example.com."}, + {"stop", []string{"prefix", "foo.", "bar.", "answer", "auto"}, "foo.example.com.", "bar.example.com.", "foo.example.com."}, + {"stop", []string{"suffix", "com", "org", "answer", "auto"}, "foo.example.com.", "foo.example.org.", "foo.example.com."}, + {"stop", []string{"suffix", ".com", ".org", "answer", "auto"}, "foo.example.com.", "foo.example.org.", "foo.example.com."}, + {"stop", []string{"suffix", ".ingress.coredns.rocks", "nginx.coredns.rocks", "answer", "auto"}, "coredns.ingress.coredns.rocks.", "corednsnginx.coredns.rocks.", "coredns.ingress.coredns.rocks."}, + } + for i, tc := range tests { + r, err := newNameRule(tc.next, tc.args...) + if err != nil { + t.Fatalf("Test %d: Expected no error, got %s", i, err) + } + + rw := Rewrite{ + Next: plugin.HandlerFunc(msgPrinter), + Rules: []Rule{r}, + RevertPolicy: NoRestorePolicy(), + } + + m := new(dns.Msg) + m.SetQuestion(tc.question, dns.TypeA) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err = rw.ServeDNS(ctx, rec, m) + if err != nil { + t.Fatalf("Test %d: Expected no error, got %s", i, err) + } + rewrite := rec.Msg.Question[0].Name + if rewrite != tc.rewrite { + t.Fatalf("Test %d: Expected question rewrite to %v, got %v", i, tc.rewrite, rewrite) + } + actual := rec.Msg.Answer[0].Header().Name + if actual != tc.expected { + t.Fatalf("Test %d: Expected answer rewrite to %v, got %v", i, tc.expected, actual) + } + } +} + +func TestRewriteNameExactAnswer(t *testing.T) { + ctx, close := context.WithCancel(context.TODO()) + defer close() + + tests := []struct { + next string + args []string + question string + rewrite string + expected string + }{ + {"stop", []string{"exact", "coredns.rocks", "service.consul", "answer", "auto"}, "coredns.rocks.", "service.consul.", "coredns.rocks."}, + {"stop", []string{"exact", "coredns.rocks.", "service.consul.", "answer", "auto"}, "coredns.rocks.", "service.consul.", "coredns.rocks."}, + {"stop", []string{"exact", "coredns.rocks", "service.consul"}, "coredns.rocks.", "service.consul.", "coredns.rocks."}, + {"stop", []string{"exact", "coredns.rocks.", "service.consul."}, "coredns.rocks.", "service.consul.", "coredns.rocks."}, + {"stop", []string{"exact", "coredns.org.", "service.consul."}, "coredns.rocks.", "coredns.rocks.", "coredns.rocks."}, + } + for i, tc := range tests { + r, err := newNameRule(tc.next, tc.args...) 
+ if err != nil { + t.Fatalf("Test %d: Expected no error, got %s", i, err) + } + + rw := Rewrite{ + Next: plugin.HandlerFunc(msgPrinter), + Rules: []Rule{r}, + RevertPolicy: NoRestorePolicy(), + } + + m := new(dns.Msg) + m.SetQuestion(tc.question, dns.TypeA) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err = rw.ServeDNS(ctx, rec, m) + if err != nil { + t.Fatalf("Test %d: Expected no error, got %s", i, err) + } + rewrite := rec.Msg.Question[0].Name + if rewrite != tc.rewrite { + t.Fatalf("Test %d: Expected question rewrite to %v, got %v", i, tc.rewrite, rewrite) + } + actual := rec.Msg.Answer[0].Header().Name + if actual != tc.expected { + t.Fatalf("Test %d: Expected answer rewrite to %v, got %v", i, tc.expected, actual) + } + } +} + +func TestRewriteNameRegexAnswer(t *testing.T) { + ctx, close := context.WithCancel(context.TODO()) + defer close() + + tests := []struct { + next string + args []string + question string + rewrite string + expected string + }{ + {"stop", []string{"regex", "(.*).coredns.rocks", "{1}.coredns.maps", "answer", "auto"}, "foo.coredns.rocks.", "foo.coredns.maps.", "foo.coredns.rocks."}, + {"stop", []string{"regex", "(.*).coredns.rocks", "{1}.coredns.maps", "answer", "name", "(.*).coredns.maps", "{1}.coredns.works"}, "foo.coredns.rocks.", "foo.coredns.maps.", "foo.coredns.works."}, + {"stop", []string{"regex", "(.*).coredns.rocks", "{1}.coredns.maps"}, "foo.coredns.rocks.", "foo.coredns.maps.", "foo.coredns.maps."}, + } + for i, tc := range tests { + r, err := newNameRule(tc.next, tc.args...) + if err != nil { + t.Fatalf("Test %d: Expected no error, got %s", i, err) + } + + rw := Rewrite{ + Next: plugin.HandlerFunc(msgPrinter), + Rules: []Rule{r}, + RevertPolicy: NoRestorePolicy(), + } + + m := new(dns.Msg) + m.SetQuestion(tc.question, dns.TypeA) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + _, err = rw.ServeDNS(ctx, rec, m) + if err != nil { + t.Fatalf("Test %d: Expected no error, got %s", i, err) + } + rewrite := rec.Msg.Question[0].Name + if rewrite != tc.rewrite { + t.Fatalf("Test %d: Expected question rewrite to %v, got %v", i, tc.rewrite, rewrite) + } + actual := rec.Msg.Answer[0].Header().Name + if actual != tc.expected { + t.Fatalf("Test %d: Expected answer rewrite to %v, got %v", i, tc.expected, actual) + } + } +} + +func TestNewNameRule(t *testing.T) { + tests := []struct { + next string + args []string + expectedFail bool + }{ + {"stop", []string{"exact", "srv3.coredns.rocks", "srv4.coredns.rocks"}, false}, + {"stop", []string{"srv1.coredns.rocks", "srv2.coredns.rocks"}, false}, + {"stop", []string{"suffix", "coredns.rocks", "coredns.rocks."}, false}, + {"stop", []string{"suffix", "coredns.rocks.", "coredns.rocks"}, false}, + {"stop", []string{"suffix", "coredns.rocks.", "coredns.rocks."}, false}, + {"stop", []string{"regex", "srv1.coredns.rocks", "10"}, false}, + {"stop", []string{"regex", "(.*).coredns.rocks", "10"}, false}, + {"stop", []string{"regex", "(.*).coredns.rocks", "{1}.coredns.rocks"}, false}, + {"stop", []string{"regex", "(.*).coredns.rocks", "{1}.{2}.coredns.rocks"}, true}, + {"stop", []string{"regex", "staging.mydomain.com", "aws-loadbalancer-id.us-east-1.elb.amazonaws.com"}, false}, + {"stop", []string{"suffix", "staging.mydomain.com", "coredns.rock", "answer"}, true}, + {"stop", []string{"suffix", "staging.mydomain.com", "coredns.rock", "answer", "name"}, true}, + {"stop", []string{"suffix", "staging.mydomain.com", "coredns.rock", "answer", "other"}, true}, + {"stop", []string{"suffix", "staging.mydomain.com", 
"coredns.rock", "answer", "auto"}, false}, + {"stop", []string{"regex", "staging.mydomain.com", "coredns.rock", "answer", "auto"}, false}, + {"stop", []string{"regex", "staging.mydomain.com", "coredns.rock", "answer", "name"}, true}, + {"stop", []string{"regex", "staging.mydomain.com", "coredns.rock", "answer", "name", "coredns.rock", "staging.mydomain.com"}, false}, + {"stop", []string{"regex", "staging.mydomain.com", "coredns.rock", "answer", "name", "(.*).coredns.rock", "{1}.{2}.staging.mydomain.com"}, true}, + + {"stop", []string{"regex", "staging.mydomain.com", "coredns.rock", "answer", "name", "(.*).coredns.rock", "{1}.staging.mydomain.com", "name", "(.*).coredns.rock", "{1}.staging.mydomain.com"}, false}, + {"stop", []string{"regex", "staging.mydomain.com", "coredns.rock", "answer", "name", "(.*).coredns.rock", "{1}.staging.mydomain.com", "answer", "name", "(.*).coredns.rock", "{1}.staging.mydomain.com"}, false}, + {"stop", []string{"regex", "staging.mydomain.com", "coredns.rock", "answer", "name", "(.*).coredns.rock", "{1}.staging.mydomain.com", "name", "(.*).coredns.rock"}, true}, + {"stop", []string{"regex", "staging.mydomain.com", "coredns.rock", "answer", "name", "(.*).coredns.rock", "{1}.staging.mydomain.com", "value", "(.*).coredns.rock", "{1}.staging.mydomain.com"}, false}, + {"stop", []string{"regex", "staging.mydomain.com", "coredns.rock", "answer", "name", "(.*).coredns.rock", "{1}.staging.mydomain.com", "answer", "value", "(.*).coredns.rock", "{1}.staging.mydomain.com"}, false}, + {"stop", []string{"regex", "staging.mydomain.com", "coredns.rock", "answer", "name", "(.*).coredns.rock", "{1}.staging.mydomain.com", "value", "(.*).coredns.rock"}, true}, + + {"stop", []string{"suffix", "staging.mydomain.com.", "coredns.rock.", "answer", "value", "(.*).coredns.rock", "{1}.staging.mydomain.com", "value", "(.*).coredns.rock", "{1}.staging.mydomain.com"}, false}, + {"stop", []string{"suffix", "staging.mydomain.com.", "coredns.rock.", "answer", "value", "(.*).coredns.rock", "{1}.staging.mydomain.com", "answer", "value", "(.*).coredns.rock", "{1}.staging.mydomain.com"}, false}, + {"stop", []string{"suffix", "staging.mydomain.com.", "coredns.rock.", "answer", "value", "(.*).coredns.rock", "{1}.staging.mydomain.com", "name", "(.*).coredns.rock", "{1}.staging.mydomain.com"}, false}, + {"stop", []string{"suffix", "staging.mydomain.com.", "coredns.rock.", "answer", "value", "(.*).coredns.rock", "{1}.staging.mydomain.com", "value", "(.*).coredns.rock"}, true}, + } + for i, tc := range tests { + failed := false + rule, err := newNameRule(tc.next, tc.args...) + if err != nil { + failed = true + } + if !failed && !tc.expectedFail { + t.Logf("Test %d: PASS, passed as expected: (%s) %s", i, tc.next, tc.args) + continue + } + if failed && tc.expectedFail { + t.Logf("Test %d: PASS, failed as expected: (%s) %s: %s", i, tc.next, tc.args, err) + continue + } + if failed && !tc.expectedFail { + t.Fatalf("Test %d: FAIL, expected fail=%t, but received fail=%t: (%s) %s, rule=%v, error=%s", i, tc.expectedFail, failed, tc.next, tc.args, rule, err) + } + t.Fatalf("Test %d: FAIL, expected fail=%t, but received fail=%t: (%s) %s, rule=%v", i, tc.expectedFail, failed, tc.next, tc.args, rule) + } + for i, tc := range tests { + failed := false + tc.args = append([]string{tc.next, "name"}, tc.args...) + rule, err := newRule(tc.args...) 
+ if err != nil { + failed = true + } + if !failed && !tc.expectedFail { + t.Logf("Test %d: PASS, passed as expected: (%s) %s", i, tc.next, tc.args) + continue + } + if failed && tc.expectedFail { + t.Logf("Test %d: PASS, failed as expected: (%s) %s: %s", i, tc.next, tc.args, err) + continue + } + t.Fatalf("Test %d: FAIL, expected fail=%t, but received fail=%t: (%s) %s, rule=%v", i, tc.expectedFail, failed, tc.next, tc.args, rule) + } +} diff --git a/plugin/rewrite/rcode.go b/plugin/rewrite/rcode.go new file mode 100644 index 0000000..814b95f --- /dev/null +++ b/plugin/rewrite/rcode.go @@ -0,0 +1,178 @@ +package rewrite + +import ( + "context" + "fmt" + "regexp" + "strconv" + "strings" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +type rcodeResponseRule struct { + old int + new int +} + +func (r *rcodeResponseRule) RewriteResponse(res *dns.Msg, rr dns.RR) { + if r.old == res.MsgHdr.Rcode { + res.MsgHdr.Rcode = r.new + } +} + +type rcodeRuleBase struct { + nextAction string + response rcodeResponseRule +} + +func newRCodeRuleBase(nextAction string, old, new int) rcodeRuleBase { + return rcodeRuleBase{ + nextAction: nextAction, + response: rcodeResponseRule{old: old, new: new}, + } +} + +func (rule *rcodeRuleBase) responseRule(match bool) (ResponseRules, Result) { + if match { + return ResponseRules{&rule.response}, RewriteDone + } + return nil, RewriteIgnored +} + +// Mode returns the processing nextAction +func (rule *rcodeRuleBase) Mode() string { return rule.nextAction } + +type exactRCodeRule struct { + rcodeRuleBase + From string +} + +type prefixRCodeRule struct { + rcodeRuleBase + Prefix string +} + +type suffixRCodeRule struct { + rcodeRuleBase + Suffix string +} + +type substringRCodeRule struct { + rcodeRuleBase + Substring string +} + +type regexRCodeRule struct { + rcodeRuleBase + Pattern *regexp.Regexp +} + +// Rewrite rewrites the current request based upon exact match of the name +// in the question section of the request. +func (rule *exactRCodeRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) { + return rule.responseRule(rule.From == state.Name()) +} + +// Rewrite rewrites the current request when the name begins with the matching string. +func (rule *prefixRCodeRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) { + return rule.responseRule(strings.HasPrefix(state.Name(), rule.Prefix)) +} + +// Rewrite rewrites the current request when the name ends with the matching string. +func (rule *suffixRCodeRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) { + return rule.responseRule(strings.HasSuffix(state.Name(), rule.Suffix)) +} + +// Rewrite rewrites the current request based upon partial match of the +// name in the question section of the request. +func (rule *substringRCodeRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) { + return rule.responseRule(strings.Contains(state.Name(), rule.Substring)) +} + +// Rewrite rewrites the current request when the name in the question +// section of the request matches a regular expression. 
+func (rule *regexRCodeRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) { + return rule.responseRule(len(rule.Pattern.FindStringSubmatch(state.Name())) != 0) +} + +// newRCodeRule creates a name matching rule based on exact, partial, or regex match +func newRCodeRule(nextAction string, args ...string) (Rule, error) { + if len(args) < 3 { + return nil, fmt.Errorf("too few (%d) arguments for a rcode rule", len(args)) + } + var oldStr, newStr string + if len(args) == 3 { + oldStr, newStr = args[1], args[2] + } + if len(args) == 4 { + oldStr, newStr = args[2], args[3] + } + old, valid := isValidRCode(oldStr) + if !valid { + return nil, fmt.Errorf("invalid matching RCODE '%s' for a rcode rule", oldStr) + } + new, valid := isValidRCode(newStr) + if !valid { + return nil, fmt.Errorf("invalid replacement RCODE '%s' for a rcode rule", newStr) + } + if len(args) == 4 { + switch strings.ToLower(args[0]) { + case ExactMatch: + return &exactRCodeRule{ + newRCodeRuleBase(nextAction, old, new), + plugin.Name(args[1]).Normalize(), + }, nil + case PrefixMatch: + return &prefixRCodeRule{ + newRCodeRuleBase(nextAction, old, new), + plugin.Name(args[1]).Normalize(), + }, nil + case SuffixMatch: + return &suffixRCodeRule{ + newRCodeRuleBase(nextAction, old, new), + plugin.Name(args[1]).Normalize(), + }, nil + case SubstringMatch: + return &substringRCodeRule{ + newRCodeRuleBase(nextAction, old, new), + plugin.Name(args[1]).Normalize(), + }, nil + case RegexMatch: + regexPattern, err := regexp.Compile(args[1]) + if err != nil { + return nil, fmt.Errorf("invalid regex pattern in a rcode rule: %s", args[1]) + } + return ®exRCodeRule{ + newRCodeRuleBase(nextAction, old, new), + regexPattern, + }, nil + default: + return nil, fmt.Errorf("rcode rule supports only exact, prefix, suffix, substring, and regex name matching") + } + } + if len(args) > 4 { + return nil, fmt.Errorf("many few arguments for a rcode rule") + } + return &exactRCodeRule{ + newRCodeRuleBase(nextAction, old, new), + plugin.Name(args[0]).Normalize(), + }, nil +} + +// validRCode returns true if v is valid RCode value. +func isValidRCode(v string) (int, bool) { + i, err := strconv.ParseUint(v, 10, 32) + // try parsing integer based rcode + if err == nil && i <= 23 { + return int(i), true + } + + if RCodeInt, ok := dns.StringToRcode[strings.ToUpper(v)]; ok { + return RCodeInt, true + } + return 0, false +} diff --git a/plugin/rewrite/rcode_test.go b/plugin/rewrite/rcode_test.go new file mode 100644 index 0000000..e402607 --- /dev/null +++ b/plugin/rewrite/rcode_test.go @@ -0,0 +1,72 @@ +package rewrite + +import ( + "testing" + + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +func TestNewRCodeRule(t *testing.T) { + tests := []struct { + next string + args []string + expectedFail bool + }{ + {"stop", []string{"numeric.rcode.coredns.rocks", "2", "0"}, false}, + {"stop", []string{"too.few.rcode.coredns.rocks", "2"}, true}, + {"stop", []string{"exact", "too.many.rcode.coredns.rocks", "2", "1", "0"}, true}, + {"stop", []string{"exact", "match.string.rcode.coredns.rocks", "SERVFAIL", "NOERROR"}, false}, + {"continue", []string{"regex", `(regex)\.rcode\.(coredns)\.(rocks)`, "FORMERR", "NOERROR"}, false}, + {"stop", []string{"invalid.rcode.coredns.rocks", "random", "nothing"}, true}, + } + for i, tc := range tests { + failed := false + rule, err := newRCodeRule(tc.next, tc.args...) 
+ if err != nil { + failed = true + } + if !failed && !tc.expectedFail { + continue + } + if failed && tc.expectedFail { + continue + } + t.Fatalf("Test %d: FAIL, expected fail=%t, but received fail=%t: (%s) %s, rule=%v, err=%v", i, tc.expectedFail, failed, tc.next, tc.args, rule, err) + } + for i, tc := range tests { + failed := false + tc.args = append([]string{tc.next, "rcode"}, tc.args...) + rule, err := newRule(tc.args...) + if err != nil { + failed = true + } + if !failed && !tc.expectedFail { + continue + } + if failed && tc.expectedFail { + continue + } + t.Fatalf("Test %d: FAIL, expected fail=%t, but received fail=%t: (%s) %s, rule=%v, err=%v", i, tc.expectedFail, failed, tc.next, tc.args, rule, err) + } +} + +func TestRCodeRewrite(t *testing.T) { + rule, err := newRCodeRule("stop", []string{"exact", "srv1.coredns.rocks", "SERVFAIL", "FORMERR"}...) + + m := new(dns.Msg) + m.SetQuestion("srv1.coredns.rocks.", dns.TypeA) + m.Question[0].Qclass = dns.ClassINET + m.Answer = []dns.RR{test.A("srv1.coredns.rocks. 5 IN A 10.0.0.1")} + m.MsgHdr.Rcode = dns.RcodeServerFailure + request := request.Request{Req: m} + + rcRule, _ := rule.(*exactRCodeRule) + var rr dns.RR + rcRule.response.RewriteResponse(request.Req, rr) + if request.Req.MsgHdr.Rcode != dns.RcodeFormatError { + t.Fatalf("RCode rewrite did not apply changes, request=%#v, err=%v", request.Req, err) + } +} diff --git a/plugin/rewrite/reverter.go b/plugin/rewrite/reverter.go new file mode 100644 index 0000000..853d96d --- /dev/null +++ b/plugin/rewrite/reverter.go @@ -0,0 +1,146 @@ +package rewrite + +import ( + "github.com/miekg/dns" +) + +// RevertPolicy controls the overall reverting process +type RevertPolicy interface { + DoRevert() bool + DoQuestionRestore() bool +} + +type revertPolicy struct { + noRevert bool + noRestore bool +} + +func (p revertPolicy) DoRevert() bool { + return !p.noRevert +} + +func (p revertPolicy) DoQuestionRestore() bool { + return !p.noRestore +} + +// NoRevertPolicy disables all response rewrite rules +func NoRevertPolicy() RevertPolicy { + return revertPolicy{true, false} +} + +// NoRestorePolicy disables the question restoration during the response rewrite +func NoRestorePolicy() RevertPolicy { + return revertPolicy{false, true} +} + +// NewRevertPolicy creates a new reverter policy by dynamically specifying all +// options. +func NewRevertPolicy(noRevert, noRestore bool) RevertPolicy { + return revertPolicy{noRestore: noRestore, noRevert: noRevert} +} + +// ResponseRule contains a rule to rewrite a response with. +type ResponseRule interface { + RewriteResponse(res *dns.Msg, rr dns.RR) +} + +// ResponseRules describes an ordered list of response rules to apply +// after a name rewrite +type ResponseRules = []ResponseRule + +// ResponseReverter reverses the operations done on the question section of a packet. +// This is need because the client will otherwise disregards the response, i.e. +// dig will complain with ';; Question section mismatch: got example.org/HINFO/IN' +type ResponseReverter struct { + dns.ResponseWriter + originalQuestion dns.Question + ResponseRules ResponseRules + revertPolicy RevertPolicy +} + +// NewResponseReverter returns a pointer to a new ResponseReverter. +func NewResponseReverter(w dns.ResponseWriter, r *dns.Msg, policy RevertPolicy) *ResponseReverter { + return &ResponseReverter{ + ResponseWriter: w, + originalQuestion: r.Question[0], + revertPolicy: policy, + } +} + +// WriteMsg records the status code and calls the underlying ResponseWriter's WriteMsg method. 
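+// The response is deep-copied before any rewrite so that a message shared
+// with other plugins (for example a cached copy) is never mutated in place.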
+func (r *ResponseReverter) WriteMsg(res1 *dns.Msg) error { + // Deep copy 'res' as to not (e.g). rewrite a message that's also stored in the cache. + res := res1.Copy() + + if r.revertPolicy.DoQuestionRestore() { + res.Question[0] = r.originalQuestion + } + if len(r.ResponseRules) > 0 { + for _, rr := range res.Ns { + r.rewriteResourceRecord(res, rr) + } + for _, rr := range res.Answer { + r.rewriteResourceRecord(res, rr) + } + for _, rr := range res.Extra { + r.rewriteResourceRecord(res, rr) + } + } + return r.ResponseWriter.WriteMsg(res) +} + +func (r *ResponseReverter) rewriteResourceRecord(res *dns.Msg, rr dns.RR) { + for _, rule := range r.ResponseRules { + rule.RewriteResponse(res, rr) + } +} + +// Write is a wrapper that records the size of the message that gets written. +func (r *ResponseReverter) Write(buf []byte) (int, error) { + n, err := r.ResponseWriter.Write(buf) + return n, err +} + +func getRecordValueForRewrite(rr dns.RR) (name string) { + switch rr.Header().Rrtype { + case dns.TypeSRV: + return rr.(*dns.SRV).Target + case dns.TypeMX: + return rr.(*dns.MX).Mx + case dns.TypeCNAME: + return rr.(*dns.CNAME).Target + case dns.TypeNS: + return rr.(*dns.NS).Ns + case dns.TypeDNAME: + return rr.(*dns.DNAME).Target + case dns.TypeNAPTR: + return rr.(*dns.NAPTR).Replacement + case dns.TypeSOA: + return rr.(*dns.SOA).Ns + case dns.TypePTR: + return rr.(*dns.PTR).Ptr + default: + return "" + } +} + +func setRewrittenRecordValue(rr dns.RR, value string) { + switch rr.Header().Rrtype { + case dns.TypeSRV: + rr.(*dns.SRV).Target = value + case dns.TypeMX: + rr.(*dns.MX).Mx = value + case dns.TypeCNAME: + rr.(*dns.CNAME).Target = value + case dns.TypeNS: + rr.(*dns.NS).Ns = value + case dns.TypeDNAME: + rr.(*dns.DNAME).Target = value + case dns.TypeNAPTR: + rr.(*dns.NAPTR).Replacement = value + case dns.TypeSOA: + rr.(*dns.SOA).Ns = value + case dns.TypePTR: + rr.(*dns.PTR).Ptr = value + } +} diff --git a/plugin/rewrite/reverter_test.go b/plugin/rewrite/reverter_test.go new file mode 100644 index 0000000..9156728 --- /dev/null +++ b/plugin/rewrite/reverter_test.go @@ -0,0 +1,177 @@ +package rewrite + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +var tests = []struct { + from string + fromType uint16 + answer []dns.RR + to string + toType uint16 + noRevert bool +}{ + {"core.dns.rocks", dns.TypeA, []dns.RR{test.A("dns.core.rocks. 5 IN A 10.0.0.1")}, "core.dns.rocks", dns.TypeA, false}, + {"core.dns.rocks", dns.TypeSRV, []dns.RR{test.SRV("dns.core.rocks. 5 IN SRV 0 100 100 srv1.dns.core.rocks.")}, "core.dns.rocks", dns.TypeSRV, false}, + {"core.dns.rocks", dns.TypeA, []dns.RR{test.A("core.dns.rocks. 5 IN A 10.0.0.1")}, "dns.core.rocks.", dns.TypeA, true}, + {"core.dns.rocks", dns.TypeSRV, []dns.RR{test.SRV("core.dns.rocks. 5 IN SRV 0 100 100 srv1.dns.core.rocks.")}, "dns.core.rocks.", dns.TypeSRV, true}, + {"core.dns.rocks", dns.TypeHINFO, []dns.RR{test.HINFO("core.dns.rocks. 5 HINFO INTEL-64 \"RHEL 7.4\"")}, "core.dns.rocks", dns.TypeHINFO, false}, + {"core.dns.rocks", dns.TypeA, []dns.RR{ + test.A("dns.core.rocks. 5 IN A 10.0.0.1"), + test.A("dns.core.rocks. 
5 IN A 10.0.0.2"), + }, "core.dns.rocks", dns.TypeA, false}, +} + +func TestResponseReverter(t *testing.T) { + rules := []Rule{} + r, _ := newNameRule("stop", "regex", `(core)\.(dns)\.(rocks)`, "{2}.{1}.{3}", "answer", "name", `(dns)\.(core)\.(rocks)`, "{2}.{1}.{3}") + rules = append(rules, r) + + doReverterTests(rules, t) + + rules = []Rule{} + r, _ = newNameRule("continue", "regex", `(core)\.(dns)\.(rocks)`, "{2}.{1}.{3}", "answer", "name", `(dns)\.(core)\.(rocks)`, "{2}.{1}.{3}") + rules = append(rules, r) + + doReverterTests(rules, t) +} + +func doReverterTests(rules []Rule, t *testing.T) { + ctx := context.TODO() + for i, tc := range tests { + m := new(dns.Msg) + m.SetQuestion(tc.from, tc.fromType) + m.Question[0].Qclass = dns.ClassINET + m.Answer = tc.answer + rw := Rewrite{ + Next: plugin.HandlerFunc(msgPrinter), + Rules: rules, + RevertPolicy: NewRevertPolicy(tc.noRevert, false), + } + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + rw.ServeDNS(ctx, rec, m) + resp := rec.Msg + if resp.Question[0].Name != tc.to { + t.Errorf("Test %d: Expected Name to be %q but was %q", i, tc.to, resp.Question[0].Name) + } + if resp.Question[0].Qtype != tc.toType { + t.Errorf("Test %d: Expected Type to be '%d' but was '%d'", i, tc.toType, resp.Question[0].Qtype) + } + } +} + +var valueTests = []struct { + from string + fromType uint16 + answer []dns.RR + extra []dns.RR + to string + toType uint16 + noRevert bool + expectValue string + expectAnswerType uint16 + expectAddlName string +}{ + {"my.domain.uk.", dns.TypeSRV, []dns.RR{test.SRV("my.cluster.local. 5 IN SRV 0 100 100 srv1.my.cluster.local.")}, []dns.RR{test.A("srv1.my.cluster.local. 5 IN A 10.0.0.1")}, "my.domain.uk.", dns.TypeSRV, false, "srv1.my.domain.uk.", dns.TypeSRV, "srv1.my.domain.uk."}, + {"my.domain.uk.", dns.TypeSRV, []dns.RR{test.SRV("my.cluster.local. 5 IN SRV 0 100 100 srv1.my.cluster.local.")}, []dns.RR{test.A("srv1.my.cluster.local. 5 IN A 10.0.0.1")}, "my.cluster.local.", dns.TypeSRV, true, "srv1.my.cluster.local.", dns.TypeSRV, "srv1.my.cluster.local."}, + {"my.domain.uk.", dns.TypeANY, []dns.RR{test.CNAME("my.cluster.local. 3600 IN CNAME cname.cluster.local.")}, []dns.RR{test.A("cname.cluster.local. 5 IN A 10.0.0.1")}, "my.domain.uk.", dns.TypeANY, false, "cname.domain.uk.", dns.TypeCNAME, "cname.domain.uk."}, + {"my.domain.uk.", dns.TypeANY, []dns.RR{test.CNAME("my.cluster.local. 3600 IN CNAME cname.cluster.local.")}, []dns.RR{test.A("cname.cluster.local. 5 IN A 10.0.0.1")}, "my.cluster.local.", dns.TypeANY, true, "cname.cluster.local.", dns.TypeCNAME, "cname.cluster.local."}, + {"my.domain.uk.", dns.TypeANY, []dns.RR{test.DNAME("my.cluster.local. 3600 IN DNAME dname.cluster.local.")}, []dns.RR{test.A("dname.cluster.local. 5 IN A 10.0.0.1")}, "my.domain.uk.", dns.TypeANY, false, "dname.domain.uk.", dns.TypeDNAME, "dname.domain.uk."}, + {"my.domain.uk.", dns.TypeANY, []dns.RR{test.DNAME("my.cluster.local. 3600 IN DNAME dname.cluster.local.")}, []dns.RR{test.A("dname.cluster.local. 5 IN A 10.0.0.1")}, "my.cluster.local.", dns.TypeANY, true, "dname.cluster.local.", dns.TypeDNAME, "dname.cluster.local."}, + {"my.domain.uk.", dns.TypeMX, []dns.RR{test.MX("my.cluster.local. 3600 IN MX 1 mx1.cluster.local.")}, []dns.RR{test.A("mx1.cluster.local. 5 IN A 10.0.0.1")}, "my.domain.uk.", dns.TypeMX, false, "mx1.domain.uk.", dns.TypeMX, "mx1.domain.uk."}, + {"my.domain.uk.", dns.TypeMX, []dns.RR{test.MX("my.cluster.local. 3600 IN MX 1 mx1.cluster.local.")}, []dns.RR{test.A("mx1.cluster.local. 
5 IN A 10.0.0.1")}, "my.cluster.local.", dns.TypeMX, true, "mx1.cluster.local.", dns.TypeMX, "mx1.cluster.local."}, + {"my.domain.uk.", dns.TypeANY, []dns.RR{test.NS("my.cluster.local. 3600 IN NS ns1.cluster.local.")}, []dns.RR{test.A("ns1.cluster.local. 5 IN A 10.0.0.1")}, "my.domain.uk.", dns.TypeANY, false, "ns1.domain.uk.", dns.TypeNS, "ns1.domain.uk."}, + {"my.domain.uk.", dns.TypeANY, []dns.RR{test.NS("my.cluster.local. 3600 IN NS ns1.cluster.local.")}, []dns.RR{test.A("ns1.cluster.local. 5 IN A 10.0.0.1")}, "my.cluster.local.", dns.TypeANY, true, "ns1.cluster.local.", dns.TypeNS, "ns1.cluster.local."}, + {"my.domain.uk.", dns.TypeSOA, []dns.RR{test.SOA("my.cluster.local. 1800 IN SOA ns1.cluster.local. admin.cluster.local. 1502165581 14400 3600 604800 14400")}, []dns.RR{test.A("ns1.cluster.local. 5 IN A 10.0.0.1")}, "my.domain.uk.", dns.TypeSOA, false, "ns1.domain.uk.", dns.TypeSOA, "ns1.domain.uk."}, + {"my.domain.uk.", dns.TypeSOA, []dns.RR{test.SOA("my.cluster.local. 1800 IN SOA ns1.cluster.local. admin.cluster.local. 1502165581 14400 3600 604800 14400")}, []dns.RR{test.A("ns1.cluster.local. 5 IN A 10.0.0.1")}, "my.cluster.local.", dns.TypeSOA, true, "ns1.cluster.local.", dns.TypeSOA, "ns1.cluster.local."}, + {"my.domain.uk.", dns.TypeNAPTR, []dns.RR{test.NAPTR("my.cluster.local. 100 IN NAPTR 100 10 \"S\" \"SIP+D2U\" \"!^.*$!sip:[email protected]!\" _sip._udp.cluster.local.")}, []dns.RR{test.A("ns1.cluster.local. 5 IN A 10.0.0.1")}, "my.domain.uk.", dns.TypeNAPTR, false, "_sip._udp.domain.uk.", dns.TypeNAPTR, "ns1.domain.uk."}, + {"my.domain.uk.", dns.TypeNAPTR, []dns.RR{test.NAPTR("my.cluster.local. 100 IN NAPTR 100 10 \"S\" \"SIP+D2U\" \"!^.*$!sip:[email protected]!\" _sip._udp.cluster.local.")}, []dns.RR{test.A("ns1.cluster.local. 5 IN A 10.0.0.1")}, "my.cluster.local.", dns.TypeNAPTR, true, "_sip._udp.cluster.local.", dns.TypeNAPTR, "ns1.cluster.local."}, +} + +func TestValueResponseReverter(t *testing.T) { + rules := []Rule{} + r, err := newNameRule("stop", "regex", `(.*)\.domain\.uk`, "{1}.cluster.local", "answer", "name", `(.*)\.cluster\.local`, "{1}.domain.uk", "answer", "value", `(.*)\.cluster\.local`, "{1}.domain.uk") + if err != nil { + t.Errorf("cannot parse rule: %s", err) + return + } + rules = append(rules, r) + + doValueReverterTests("stop", rules, t) + + rules = []Rule{} + r, err = newNameRule("continue", "regex", `(.*)\.domain\.uk`, "{1}.cluster.local", "answer", "name", `(.*)\.cluster\.local`, "{1}.domain.uk", "answer", "value", `(.*)\.cluster\.local`, "{1}.domain.uk") + if err != nil { + t.Errorf("cannot parse rule: %s", err) + return + } + rules = append(rules, r) + + doValueReverterTests("continue", rules, t) + + rules = []Rule{} + r, err = newNameRule("stop", "suffix", `.domain.uk`, ".cluster.local", "answer", "auto", "answer", "value", `(.*)\.cluster\.local`, "{1}.domain.uk") + if err != nil { + t.Errorf("cannot parse rule: %s", err) + return + } + rules = append(rules, r) + + doValueReverterTests("suffix", rules, t) +} + +func doValueReverterTests(name string, rules []Rule, t *testing.T) { + ctx := context.TODO() + for i, tc := range valueTests { + m := new(dns.Msg) + m.SetQuestion(tc.from, tc.fromType) + m.Question[0].Qclass = dns.ClassINET + m.Answer = tc.answer + m.Extra = tc.extra + rw := Rewrite{ + Next: plugin.HandlerFunc(msgPrinter), + Rules: rules, + RevertPolicy: NewRevertPolicy(tc.noRevert, false), + } + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + rw.ServeDNS(ctx, rec, m) + resp := rec.Msg + if resp.Question[0].Name != tc.to { + 
t.Errorf("Test %s.%d: Expected Name to be %q but was %q", name, i, tc.to, resp.Question[0].Name) + } + if resp.Question[0].Qtype != tc.toType { + t.Errorf("Test %s.%d: Expected Type to be '%d' but was '%d'", name, i, tc.toType, resp.Question[0].Qtype) + } + + if len(resp.Answer) <= 0 { + t.Errorf("Test %s.%d: No Answers", name, i) + return + } + if len(resp.Answer) > 0 && resp.Answer[0].Header().Rrtype != tc.expectAnswerType { + t.Errorf("Test %s.%d: Unexpected Answer Record Type %d", name, i, resp.Answer[0].Header().Rrtype) + return + } + + value := getRecordValueForRewrite(resp.Answer[0]) + if value != tc.expectValue { + t.Errorf("Test %s.%d: Expected Target to be '%s' but was '%s'", name, i, tc.expectValue, value) + } + + if len(resp.Extra) <= 0 || resp.Extra[0].Header().Rrtype != dns.TypeA { + t.Errorf("Test %s.%d: Unexpected Additional Record Type / No Additional Records", name, i) + return + } + + if resp.Extra[0].Header().Name != tc.expectAddlName { + t.Errorf("Test %s.%d: Expected Extra Name to be %q but was %q", name, i, tc.expectAddlName, resp.Extra[0].Header().Name) + } + } +} diff --git a/plugin/rewrite/rewrite.go b/plugin/rewrite/rewrite.go new file mode 100644 index 0000000..edb1181 --- /dev/null +++ b/plugin/rewrite/rewrite.go @@ -0,0 +1,149 @@ +package rewrite + +import ( + "context" + "fmt" + "strings" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// Result is the result of a rewrite +type Result int + +const ( + // RewriteIgnored is returned when rewrite is not done on request. + RewriteIgnored Result = iota + // RewriteDone is returned when rewrite is done on request. + RewriteDone +) + +// These are defined processing mode. +const ( + // Processing should stop after completing this rule + Stop = "stop" + // Processing should continue to next rule + Continue = "continue" +) + +// Rewrite is a plugin to rewrite requests internally before being handled. +type Rewrite struct { + Next plugin.Handler + Rules []Rule + RevertPolicy +} + +// ServeDNS implements the plugin.Handler interface. +func (rw Rewrite) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + if rw.RevertPolicy == nil { + rw.RevertPolicy = NewRevertPolicy(false, false) + } + wr := NewResponseReverter(w, r, rw.RevertPolicy) + state := request.Request{W: w, Req: r} + + for _, rule := range rw.Rules { + respRules, result := rule.Rewrite(ctx, state) + if result == RewriteDone { + if _, ok := dns.IsDomainName(state.Req.Question[0].Name); !ok { + err := fmt.Errorf("invalid name after rewrite: %s", state.Req.Question[0].Name) + state.Req.Question[0] = wr.originalQuestion + return dns.RcodeServerFailure, err + } + wr.ResponseRules = append(wr.ResponseRules, respRules...) + if rule.Mode() == Stop { + if !rw.RevertPolicy.DoRevert() { + return plugin.NextOrFailure(rw.Name(), rw.Next, ctx, w, r) + } + rcode, err := plugin.NextOrFailure(rw.Name(), rw.Next, ctx, wr, r) + if plugin.ClientWrite(rcode) { + return rcode, err + } + // The next plugins didn't write a response, so write one now with the ResponseReverter. + // If server.ServeDNS does this then it will create an answer mismatch. 
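+ // Writing through the ResponseReverter also restores the original question
+ // (subject to the revert policy) and applies any collected response rules.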
+ res := new(dns.Msg).SetRcode(r, rcode) + state.SizeAndDo(res) + wr.WriteMsg(res) + // return success, so server does not write a second error response to client + return dns.RcodeSuccess, err + } + } + } + if !rw.RevertPolicy.DoRevert() || len(wr.ResponseRules) == 0 { + return plugin.NextOrFailure(rw.Name(), rw.Next, ctx, w, r) + } + return plugin.NextOrFailure(rw.Name(), rw.Next, ctx, wr, r) +} + +// Name implements the Handler interface. +func (rw Rewrite) Name() string { return "rewrite" } + +// Rule describes a rewrite rule. +type Rule interface { + // Rewrite rewrites the current request. + Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) + // Mode returns the processing mode stop or continue. + Mode() string +} + +func newRule(args ...string) (Rule, error) { + if len(args) == 0 { + return nil, fmt.Errorf("no rule type specified for rewrite") + } + + arg0 := strings.ToLower(args[0]) + var ruleType string + var expectNumArgs, startArg int + mode := Stop + switch arg0 { + case Continue: + mode = Continue + if len(args) < 2 { + return nil, fmt.Errorf("continue rule must begin with a rule type") + } + ruleType = strings.ToLower(args[1]) + expectNumArgs = len(args) - 1 + startArg = 2 + case Stop: + if len(args) < 2 { + return nil, fmt.Errorf("stop rule must begin with a rule type") + } + ruleType = strings.ToLower(args[1]) + expectNumArgs = len(args) - 1 + startArg = 2 + default: + // for backward compatibility + ruleType = arg0 + expectNumArgs = len(args) + startArg = 1 + } + + switch ruleType { + case "answer": + return nil, fmt.Errorf("response rewrites must begin with a name rule") + case "name": + return newNameRule(mode, args[startArg:]...) + case "class": + if expectNumArgs != 3 { + return nil, fmt.Errorf("%s rules must have exactly two arguments", ruleType) + } + return newClassRule(mode, args[startArg:]...) + case "type": + if expectNumArgs != 3 { + return nil, fmt.Errorf("%s rules must have exactly two arguments", ruleType) + } + return newTypeRule(mode, args[startArg:]...) + case "edns0": + return newEdns0Rule(mode, args[startArg:]...) + case "ttl": + return newTTLRule(mode, args[startArg:]...) + case "cname": + return newCNAMERule(mode, args[startArg:]...) + case "rcode": + return newRCodeRule(mode, args[startArg:]...) 
+ default: + return nil, fmt.Errorf("invalid rule type %q", args[0]) + } +} diff --git a/plugin/rewrite/rewrite_test.go b/plugin/rewrite/rewrite_test.go new file mode 100644 index 0000000..03d4fff --- /dev/null +++ b/plugin/rewrite/rewrite_test.go @@ -0,0 +1,747 @@ +package rewrite + +import ( + "bytes" + "context" + "fmt" + "reflect" + "testing" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/metadata" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +func msgPrinter(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + if len(r.Answer) == 0 { + r.Answer = []dns.RR{ + test.A(fmt.Sprintf("%s 5 IN A 10.0.0.1", r.Question[0].Name)), + } + } + w.WriteMsg(r) + return 0, nil +} + +func TestNewRule(t *testing.T) { + tests := []struct { + args []string + shouldError bool + expType reflect.Type + }{ + {[]string{}, true, nil}, + {[]string{"foo"}, true, nil}, + {[]string{"name"}, true, nil}, + {[]string{"name", "a.com"}, true, nil}, + {[]string{"name", "a.com", "b.com", "c.com"}, true, nil}, + {[]string{"name", "a.com", "b.com"}, false, reflect.TypeOf(&exactNameRule{})}, + {[]string{"name", "exact", "a.com", "b.com"}, false, reflect.TypeOf(&exactNameRule{})}, + {[]string{"name", "prefix", "a.com", "b.com"}, false, reflect.TypeOf(&prefixNameRule{})}, + {[]string{"name", "suffix", "a.com", "b.com"}, false, reflect.TypeOf(&suffixNameRule{})}, + {[]string{"name", "substring", "a.com", "b.com"}, false, reflect.TypeOf(&substringNameRule{})}, + {[]string{"name", "regex", "([a])\\.com", "new-{1}.com"}, false, reflect.TypeOf(®exNameRule{})}, + {[]string{"name", "regex", "([a]\\.com", "new-{1}.com"}, true, nil}, + {[]string{"name", "regex", "(dns)\\.(core)\\.(rocks)", "{2}.{1}.{3}", "answer", "name", "(core)\\.(dns)\\.(rocks)", "{2}.{1}.{3}"}, false, reflect.TypeOf(®exNameRule{})}, + {[]string{"name", "regex", "(adns)\\.(core)\\.(rocks)", "{2}.{1}.{3}", "answer", "name", "(core)\\.(adns)\\.(rocks)", "{2}.{1}.{3}", "too.long", "way.too.long"}, true, nil}, + {[]string{"name", "regex", "(bdns)\\.(core)\\.(rocks)", "{2}.{1}.{3}", "NoAnswer", "name", "(core)\\.(bdns)\\.(rocks)", "{2}.{1}.{3}"}, true, nil}, + {[]string{"name", "regex", "(cdns)\\.(core)\\.(rocks)", "{2}.{1}.{3}", "answer", "ttl", "(core)\\.(cdns)\\.(rocks)", "{2}.{1}.{3}"}, true, nil}, + {[]string{"name", "regex", "(ddns)\\.(core)\\.(rocks)", "{2}.{1}.{3}", "answer", "name", "\xecore\\.(ddns)\\.(rocks)", "{2}.{1}.{3}"}, true, nil}, + {[]string{"name", "regex", "\xedns\\.(core)\\.(rocks)", "{2}.{1}.{3}", "answer", "name", "(core)\\.(edns)\\.(rocks)", "{2}.{1}.{3}"}, true, nil}, + {[]string{"name", "substring", "fcore.dns.rocks", "dns.fcore.rocks", "answer", "name", "(fcore)\\.(dns)\\.(rocks)", "{2}.{1}.{3}"}, false, reflect.TypeOf(&substringNameRule{})}, + {[]string{"name", "substring", "a.com", "b.com", "c.com"}, true, nil}, + {[]string{"type"}, true, nil}, + {[]string{"type", "a"}, true, nil}, + {[]string{"type", "any", "a", "a"}, true, nil}, + {[]string{"type", "any", "a"}, false, reflect.TypeOf(&typeRule{})}, + {[]string{"type", "XY", "WV"}, true, nil}, + {[]string{"type", "ANY", "WV"}, true, nil}, + {[]string{"class"}, true, nil}, + {[]string{"class", "IN"}, true, nil}, + {[]string{"class", "ch", "in", "in"}, true, nil}, + {[]string{"class", "ch", "in"}, false, reflect.TypeOf(&classRule{})}, + {[]string{"class", "XY", "WV"}, true, nil}, + {[]string{"class", "IN", "WV"}, true, nil}, + 
{[]string{"edns0"}, true, nil}, + {[]string{"edns0", "local"}, true, nil}, + {[]string{"edns0", "local", "set"}, true, nil}, + {[]string{"edns0", "local", "set", "0xffee"}, true, nil}, + {[]string{"edns0", "local", "set", "65518", "abcdefg"}, false, reflect.TypeOf(&edns0LocalRule{})}, + {[]string{"edns0", "local", "set", "0xffee", "abcdefg"}, false, reflect.TypeOf(&edns0LocalRule{})}, + {[]string{"edns0", "local", "append", "0xffee", "abcdefg"}, false, reflect.TypeOf(&edns0LocalRule{})}, + {[]string{"edns0", "local", "replace", "0xffee", "abcdefg"}, false, reflect.TypeOf(&edns0LocalRule{})}, + {[]string{"edns0", "local", "foo", "0xffee", "abcdefg"}, true, nil}, + {[]string{"edns0", "local", "set", "0xffee", "0xabcdefg"}, true, nil}, + {[]string{"edns0", "nsid", "set", "junk"}, true, nil}, + {[]string{"edns0", "nsid", "set"}, false, reflect.TypeOf(&edns0NsidRule{})}, + {[]string{"edns0", "nsid", "append"}, false, reflect.TypeOf(&edns0NsidRule{})}, + {[]string{"edns0", "nsid", "replace"}, false, reflect.TypeOf(&edns0NsidRule{})}, + {[]string{"edns0", "nsid", "foo"}, true, nil}, + {[]string{"edns0", "local", "set", "0xffee", "{dummy}"}, true, nil}, + {[]string{"edns0", "local", "set", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "set", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "set", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "set", "0xffee", "{client_port}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "set", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "set", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "set", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "append", "0xffee", "{dummy}"}, true, nil}, + {[]string{"edns0", "local", "append", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "append", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "append", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "append", "0xffee", "{client_port}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "append", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "append", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "append", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "replace", "0xffee", "{dummy}"}, true, nil}, + {[]string{"edns0", "local", "replace", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "replace", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "replace", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "replace", "0xffee", "{client_port}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "replace", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", "replace", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "local", 
"replace", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"edns0", "subnet", "set", "-1", "56"}, true, nil}, + {[]string{"edns0", "subnet", "set", "24", "-56"}, true, nil}, + {[]string{"edns0", "subnet", "set", "33", "56"}, true, nil}, + {[]string{"edns0", "subnet", "set", "24", "129"}, true, nil}, + {[]string{"edns0", "subnet", "set", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})}, + {[]string{"edns0", "subnet", "append", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})}, + {[]string{"edns0", "subnet", "replace", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})}, + {[]string{"unknown-action", "name", "a.com", "b.com"}, true, nil}, + {[]string{"stop", "name", "a.com", "b.com"}, false, reflect.TypeOf(&exactNameRule{})}, + {[]string{"continue", "name", "a.com", "b.com"}, false, reflect.TypeOf(&exactNameRule{})}, + {[]string{"unknown-action", "type", "any", "a"}, true, nil}, + {[]string{"stop", "type", "any", "a"}, false, reflect.TypeOf(&typeRule{})}, + {[]string{"continue", "type", "any", "a"}, false, reflect.TypeOf(&typeRule{})}, + {[]string{"unknown-action", "class", "ch", "in"}, true, nil}, + {[]string{"stop", "class", "ch", "in"}, false, reflect.TypeOf(&classRule{})}, + {[]string{"continue", "class", "ch", "in"}, false, reflect.TypeOf(&classRule{})}, + {[]string{"unknown-action", "edns0", "local", "set", "0xffee", "abcedef"}, true, nil}, + {[]string{"stop", "edns0", "local", "set", "0xffee", "abcdefg"}, false, reflect.TypeOf(&edns0LocalRule{})}, + {[]string{"continue", "edns0", "local", "set", "0xffee", "abcdefg"}, false, reflect.TypeOf(&edns0LocalRule{})}, + {[]string{"unknown-action", "edns0", "nsid", "set"}, true, nil}, + {[]string{"stop", "edns0", "nsid", "set"}, false, reflect.TypeOf(&edns0NsidRule{})}, + {[]string{"continue", "edns0", "nsid", "set"}, false, reflect.TypeOf(&edns0NsidRule{})}, + {[]string{"unknown-action", "edns0", "local", "set", "0xffee", "{qname}"}, true, nil}, + {[]string{"stop", "edns0", "local", "set", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"stop", "edns0", "local", "set", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"stop", "edns0", "local", "set", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"stop", "edns0", "local", "set", "0xffee", "{client_port}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"stop", "edns0", "local", "set", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"stop", "edns0", "local", "set", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"stop", "edns0", "local", "set", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"continue", "edns0", "local", "set", "0xffee", "{qname}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"continue", "edns0", "local", "set", "0xffee", "{qtype}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"continue", "edns0", "local", "set", "0xffee", "{client_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"continue", "edns0", "local", "set", "0xffee", "{client_port}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"continue", "edns0", "local", "set", "0xffee", "{protocol}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"continue", "edns0", "local", "set", "0xffee", "{server_ip}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"continue", "edns0", "local", 
"set", "0xffee", "{server_port}"}, false, reflect.TypeOf(&edns0VariableRule{})}, + {[]string{"unknown-action", "edns0", "subnet", "set", "24", "64"}, true, nil}, + {[]string{"stop", "edns0", "subnet", "set", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})}, + {[]string{"stop", "edns0", "subnet", "append", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})}, + {[]string{"stop", "edns0", "subnet", "replace", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})}, + {[]string{"continue", "edns0", "subnet", "set", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})}, + {[]string{"continue", "edns0", "subnet", "append", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})}, + {[]string{"continue", "edns0", "subnet", "replace", "24", "56"}, false, reflect.TypeOf(&edns0SubnetRule{})}, + } + + for i, tc := range tests { + r, err := newRule(tc.args...) + if err == nil && tc.shouldError { + t.Errorf("Test %d: expected error but got success", i) + } else if err != nil && !tc.shouldError { + t.Errorf("Test %d: expected success but got error: %s", i, err) + } + + if !tc.shouldError && reflect.TypeOf(r) != tc.expType { + t.Errorf("Test %d: expected %q but got %q", i, tc.expType, r) + } + } +} + +func TestRewriteDefaultRevertPolicy(t *testing.T) { + rules := []Rule{} + + r, _ := newNameRule("stop", "prefix", "prefix", "to") + rules = append(rules, r) + r, _ = newNameRule("stop", "suffix", ".suffix.", ".nl.") + rules = append(rules, r) + r, _ = newNameRule("stop", "substring", "from.substring", "to") + rules = append(rules, r) + r, _ = newNameRule("stop", "regex", "(f.*m)\\.regex\\.(nl)", "to.{2}") + rules = append(rules, r) + + rw := Rewrite{ + Next: plugin.HandlerFunc(msgPrinter), + Rules: rules, + // use production (default) RevertPolicy + } + + tests := []struct { + from string + fromT uint16 + fromC uint16 + to string + toT uint16 + toC uint16 + }{ + {"prefix.nl.", dns.TypeA, dns.ClassINET, "to.nl.", dns.TypeA, dns.ClassINET}, + {"to.suffix.", dns.TypeA, dns.ClassINET, "to.nl.", dns.TypeA, dns.ClassINET}, + {"from.substring.nl.", dns.TypeA, dns.ClassINET, "to.nl.", dns.TypeA, dns.ClassINET}, + {"from.regex.nl.", dns.TypeA, dns.ClassINET, "to.nl.", dns.TypeA, dns.ClassINET}, + } + + ctx := context.TODO() + for i, tc := range tests { + m := new(dns.Msg) + m.SetQuestion(tc.from, tc.fromT) + m.Question[0].Qclass = tc.fromC + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + rw.ServeDNS(ctx, rec, m) + + resp := rec.Msg + + if resp.Question[0].Name != tc.from { + t.Errorf("Test %d: Expected Name in Question to be %q but was %q", i, tc.from, resp.Question[0].Name) + } + + if resp.Answer[0].Header().Name != tc.to { + t.Errorf("Test %d: Expected Name in Answer to be %q but was %q", i, tc.to, resp.Answer[0].Header().Name) + } + } +} + +func TestRewrite(t *testing.T) { + rules := []Rule{} + r, _ := newNameRule("stop", "from.nl.", "to.nl.") + rules = append(rules, r) + r, _ = newNameRule("stop", "regex", "(core)\\.(dns)\\.(rocks)\\.(nl)", "{2}.{1}.{3}.{4}", "answer", "name", "(dns)\\.(core)\\.(rocks)\\.(nl)", "{2}.{1}.{3}.{4}") + rules = append(rules, r) + r, _ = newNameRule("stop", "exact", "from.exact.nl.", "to.nl.") + rules = append(rules, r) + r, _ = newNameRule("stop", "prefix", "prefix", "to") + rules = append(rules, r) + r, _ = newNameRule("stop", "suffix", ".suffix.", ".nl.") + rules = append(rules, r) + r, _ = newNameRule("stop", "substring", "from.substring", "to") + rules = append(rules, r) + r, _ = newNameRule("stop", "regex", "(f.*m)\\.regex\\.(nl)", "to.{2}") + rules = 
append(rules, r) + r, _ = newNameRule("continue", "regex", "consul\\.(rocks)", "core.dns.{1}") + rules = append(rules, r) + r, _ = newNameRule("stop", "core.dns.rocks", "to.nl.") + rules = append(rules, r) + r, _ = newClassRule("continue", "HS", "CH") + rules = append(rules, r) + r, _ = newClassRule("stop", "CH", "IN") + rules = append(rules, r) + r, _ = newTypeRule("stop", "ANY", "HINFO") + rules = append(rules, r) + + rw := Rewrite{ + Next: plugin.HandlerFunc(msgPrinter), + Rules: rules, + RevertPolicy: NoRevertPolicy(), + } + + tests := []struct { + from string + fromT uint16 + fromC uint16 + to string + toT uint16 + toC uint16 + }{ + {"from.nl.", dns.TypeA, dns.ClassINET, "to.nl.", dns.TypeA, dns.ClassINET}, + {"a.nl.", dns.TypeA, dns.ClassINET, "a.nl.", dns.TypeA, dns.ClassINET}, + {"a.nl.", dns.TypeA, dns.ClassCHAOS, "a.nl.", dns.TypeA, dns.ClassINET}, + {"a.nl.", dns.TypeANY, dns.ClassINET, "a.nl.", dns.TypeHINFO, dns.ClassINET}, + // name is rewritten, type is not. + {"from.nl.", dns.TypeANY, dns.ClassINET, "to.nl.", dns.TypeANY, dns.ClassINET}, + {"from.exact.nl.", dns.TypeA, dns.ClassINET, "to.nl.", dns.TypeA, dns.ClassINET}, + {"prefix.nl.", dns.TypeA, dns.ClassINET, "to.nl.", dns.TypeA, dns.ClassINET}, + {"to.suffix.", dns.TypeA, dns.ClassINET, "to.nl.", dns.TypeA, dns.ClassINET}, + {"from.substring.nl.", dns.TypeA, dns.ClassINET, "to.nl.", dns.TypeA, dns.ClassINET}, + {"from.regex.nl.", dns.TypeA, dns.ClassINET, "to.nl.", dns.TypeA, dns.ClassINET}, + {"consul.rocks.", dns.TypeA, dns.ClassINET, "to.nl.", dns.TypeA, dns.ClassINET}, + // name is not, type is, but class is, because class is the 2nd rule. + {"a.nl.", dns.TypeANY, dns.ClassCHAOS, "a.nl.", dns.TypeANY, dns.ClassINET}, + // class gets rewritten twice because of continue/stop logic: HS to CH, CH to IN + {"a.nl.", dns.TypeANY, 4, "a.nl.", dns.TypeANY, dns.ClassINET}, + {"core.dns.rocks.nl.", dns.TypeA, dns.ClassINET, "dns.core.rocks.nl.", dns.TypeA, dns.ClassINET}, + } + + ctx := context.TODO() + for i, tc := range tests { + m := new(dns.Msg) + m.SetQuestion(tc.from, tc.fromT) + m.Question[0].Qclass = tc.fromC + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + rw.ServeDNS(ctx, rec, m) + + resp := rec.Msg + if resp.Question[0].Name != tc.to { + t.Errorf("Test %d: Expected Name to be %q but was %q", i, tc.to, resp.Question[0].Name) + } + if resp.Question[0].Qtype != tc.toT { + t.Errorf("Test %d: Expected Type to be '%d' but was '%d'", i, tc.toT, resp.Question[0].Qtype) + } + if resp.Question[0].Qclass != tc.toC { + t.Errorf("Test %d: Expected Class to be '%d' but was '%d'", i, tc.toC, resp.Question[0].Qclass) + } + if tc.fromT == dns.TypeA && tc.toT == dns.TypeA { + if len(resp.Answer) > 0 { + if resp.Answer[0].(*dns.A).Hdr.Name != tc.to { + t.Errorf("Test %d: Expected Answer Name to be %q but was %q", i, tc.to, resp.Answer[0].(*dns.A).Hdr.Name) + } + } + } + } +} + +func TestRewriteEDNS0Local(t *testing.T) { + rw := Rewrite{ + Next: plugin.HandlerFunc(msgPrinter), + RevertPolicy: NoRevertPolicy(), + } + + tests := []struct { + fromOpts []dns.EDNS0 + args []string + toOpts []dns.EDNS0 + doBool bool + }{ + { + []dns.EDNS0{}, + []string{"local", "set", "0xffee", "0xabcdef"}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte{0xab, 0xcd, 0xef}}}, + false, + }, + { + []dns.EDNS0{}, + []string{"local", "append", "0xffee", "abcdefghijklmnop"}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("abcdefghijklmnop")}}, + false, + }, + { + []dns.EDNS0{}, + []string{"local", "replace", "0xffee", 
"abcdefghijklmnop"}, + []dns.EDNS0{}, + true, + }, + { + []dns.EDNS0{}, + []string{"nsid", "set"}, + []dns.EDNS0{&dns.EDNS0_NSID{Code: dns.EDNS0NSID, Nsid: ""}}, + false, + }, + { + []dns.EDNS0{}, + []string{"nsid", "append"}, + []dns.EDNS0{&dns.EDNS0_NSID{Code: dns.EDNS0NSID, Nsid: ""}}, + true, + }, + { + []dns.EDNS0{}, + []string{"nsid", "replace"}, + []dns.EDNS0{}, + true, + }, + } + + ctx := context.TODO() + for i, tc := range tests { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + m.Question[0].Qclass = dns.ClassINET + + r, err := newEdns0Rule("stop", tc.args...) + if err != nil { + t.Errorf("Error creating test rule: %s", err) + continue + } + rw.Rules = []Rule{r} + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + rw.ServeDNS(ctx, rec, m) + + resp := rec.Msg + o := resp.IsEdns0() + o.SetDo(tc.doBool) + if o == nil { + t.Errorf("Test %d: EDNS0 options not set", i) + continue + } + if o.Do() != tc.doBool { + t.Errorf("Test %d: Expected %v but got %v", i, tc.doBool, o.Do()) + } + if !optsEqual(o.Option, tc.toOpts) { + t.Errorf("Test %d: Expected %v but got %v", i, tc.toOpts, o) + } + } +} + +func TestEdns0LocalMultiRule(t *testing.T) { + rules := []Rule{} + r, _ := newEdns0Rule("stop", "local", "replace", "0xffee", "abcdef") + rules = append(rules, r) + r, _ = newEdns0Rule("stop", "local", "set", "0xffee", "fedcba") + rules = append(rules, r) + + rw := Rewrite{ + Next: plugin.HandlerFunc(msgPrinter), + Rules: rules, + RevertPolicy: NoRevertPolicy(), + } + + tests := []struct { + fromOpts []dns.EDNS0 + toOpts []dns.EDNS0 + }{ + { + nil, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("fedcba")}}, + }, + { + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("foobar")}}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("abcdef")}}, + }, + } + + ctx := context.TODO() + for i, tc := range tests { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + m.Question[0].Qclass = dns.ClassINET + if tc.fromOpts != nil { + o := m.IsEdns0() + if o == nil { + m.SetEdns0(4096, true) + o = m.IsEdns0() + } + o.Option = append(o.Option, tc.fromOpts...) 
+ } + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + rw.ServeDNS(ctx, rec, m) + + resp := rec.Msg + o := resp.IsEdns0() + if o == nil { + t.Errorf("Test %d: EDNS0 options not set", i) + continue + } + if !optsEqual(o.Option, tc.toOpts) { + t.Errorf("Test %d: Expected %v but got %v", i, tc.toOpts, o) + } + } +} + +func optsEqual(a, b []dns.EDNS0) bool { + if len(a) != len(b) { + return false + } + for i := range a { + switch aa := a[i].(type) { + case *dns.EDNS0_LOCAL: + if bb, ok := b[i].(*dns.EDNS0_LOCAL); ok { + if aa.Code != bb.Code { + return false + } + if !bytes.Equal(aa.Data, bb.Data) { + return false + } + } else { + return false + } + case *dns.EDNS0_NSID: + if bb, ok := b[i].(*dns.EDNS0_NSID); ok { + if aa.Nsid != bb.Nsid { + return false + } + } else { + return false + } + case *dns.EDNS0_SUBNET: + if bb, ok := b[i].(*dns.EDNS0_SUBNET); ok { + if aa.Code != bb.Code { + return false + } + if aa.Family != bb.Family { + return false + } + if aa.SourceNetmask != bb.SourceNetmask { + return false + } + if aa.SourceScope != bb.SourceScope { + return false + } + if !aa.Address.Equal(bb.Address) { + return false + } + } else { + return false + } + + default: + return false + } + } + return true +} + +type testProvider map[string]metadata.Func + +func (tp testProvider) Metadata(ctx context.Context, state request.Request) context.Context { + for k, v := range tp { + metadata.SetValueFunc(ctx, k, v) + } + return ctx +} + +func TestRewriteEDNS0LocalVariable(t *testing.T) { + rw := Rewrite{ + Next: plugin.HandlerFunc(msgPrinter), + RevertPolicy: NoRevertPolicy(), + } + + expectedMetadata := []metadata.Provider{ + testProvider{"test/label": func() string { return "my-value" }}, + testProvider{"test/empty": func() string { return "" }}, + } + + meta := metadata.Metadata{ + Zones: []string{"."}, + Providers: expectedMetadata, + Next: &rw, + } + + // test.ResponseWriter has the following values: + // The remote will always be 10.240.0.1 and port 40212. + // The local address is always 127.0.0.1 and port 53. 
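+	// Derived from those addresses, the expected payloads in the table below are:
+	// client_ip 10.240.0.1 -> 0x0A 0xF0 0x00 0x01, client_port 40212 -> 0x9D 0x14,
+	// server_ip 127.0.0.1 -> 0x7F 0x00 0x00 0x01, server_port 53 -> 0x00 0x35 (big endian).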
+ + tests := []struct { + fromOpts []dns.EDNS0 + args []string + toOpts []dns.EDNS0 + doBool bool + }{ + { + []dns.EDNS0{}, + []string{"local", "set", "0xffee", "{qname}"}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("example.com.")}}, + true, + }, + { + []dns.EDNS0{}, + []string{"local", "set", "0xffee", "{qtype}"}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte{0x00, 0x01}}}, + false, + }, + { + []dns.EDNS0{}, + []string{"local", "set", "0xffee", "{client_ip}"}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte{0x0A, 0xF0, 0x00, 0x01}}}, + false, + }, + { + []dns.EDNS0{}, + []string{"local", "set", "0xffee", "{client_port}"}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte{0x9D, 0x14}}}, + true, + }, + { + []dns.EDNS0{}, + []string{"local", "set", "0xffee", "{protocol}"}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("udp")}}, + false, + }, + { + []dns.EDNS0{}, + []string{"local", "set", "0xffee", "{server_port}"}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte{0x00, 0x35}}}, + true, + }, + { + []dns.EDNS0{}, + []string{"local", "set", "0xffee", "{server_ip}"}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte{0x7F, 0x00, 0x00, 0x01}}}, + true, + }, + { + []dns.EDNS0{}, + []string{"local", "set", "0xffee", "{test/label}"}, + []dns.EDNS0{&dns.EDNS0_LOCAL{Code: 0xffee, Data: []byte("my-value")}}, + true, + }, + { + []dns.EDNS0{}, + []string{"local", "set", "0xffee", "{test/empty}"}, + nil, + false, + }, + { + []dns.EDNS0{}, + []string{"local", "set", "0xffee", "{test/does-not-exist}"}, + nil, + false, + }, + } + + for i, tc := range tests { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + + r, err := newEdns0Rule("stop", tc.args...) + if err != nil { + t.Errorf("Error creating test rule: %s", err) + continue + } + rw.Rules = []Rule{r} + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + ctx := meta.Collect(context.TODO(), request.Request{W: rec, Req: m}) + meta.ServeDNS(ctx, rec, m) + + resp := rec.Msg + o := resp.IsEdns0() + if o == nil { + if tc.toOpts != nil { + t.Errorf("Test %d: EDNS0 options not set", i) + } + continue + } + o.SetDo(tc.doBool) + if o.Do() != tc.doBool { + t.Errorf("Test %d: Expected %v but got %v", i, tc.doBool, o.Do()) + } + if !optsEqual(o.Option, tc.toOpts) { + t.Errorf("Test %d: Expected %v but got %v", i, tc.toOpts, o) + } + } +} + +func TestRewriteEDNS0Subnet(t *testing.T) { + rw := Rewrite{ + Next: plugin.HandlerFunc(msgPrinter), + RevertPolicy: NoRevertPolicy(), + } + + tests := []struct { + writer dns.ResponseWriter + fromOpts []dns.EDNS0 + args []string + toOpts []dns.EDNS0 + doBool bool + }{ + { + &test.ResponseWriter{}, + []dns.EDNS0{}, + []string{"subnet", "set", "24", "56"}, + []dns.EDNS0{&dns.EDNS0_SUBNET{Code: 0x8, + Family: 0x1, + SourceNetmask: 0x18, + SourceScope: 0x0, + Address: []byte{0x0A, 0xF0, 0x00, 0x00}, + }}, + true, + }, + { + &test.ResponseWriter{}, + []dns.EDNS0{}, + []string{"subnet", "set", "32", "56"}, + []dns.EDNS0{&dns.EDNS0_SUBNET{Code: 0x8, + Family: 0x1, + SourceNetmask: 0x20, + SourceScope: 0x0, + Address: []byte{0x0A, 0xF0, 0x00, 0x01}, + }}, + false, + }, + { + &test.ResponseWriter{}, + []dns.EDNS0{}, + []string{"subnet", "set", "0", "56"}, + []dns.EDNS0{&dns.EDNS0_SUBNET{Code: 0x8, + Family: 0x1, + SourceNetmask: 0x0, + SourceScope: 0x0, + Address: []byte{0x00, 0x00, 0x00, 0x00}, + }}, + false, + }, + { + &test.ResponseWriter6{}, + []dns.EDNS0{}, + []string{"subnet", "set", "24", "56"}, + []dns.EDNS0{&dns.EDNS0_SUBNET{Code: 0x8, + 
Family: 0x2, + SourceNetmask: 0x38, + SourceScope: 0x0, + Address: []byte{0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + }}, + true, + }, + { + &test.ResponseWriter6{}, + []dns.EDNS0{}, + []string{"subnet", "set", "24", "128"}, + []dns.EDNS0{&dns.EDNS0_SUBNET{Code: 0x8, + Family: 0x2, + SourceNetmask: 0x80, + SourceScope: 0x0, + Address: []byte{0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x42, 0x00, 0xff, 0xfe, 0xca, 0x4c, 0x65}, + }}, + false, + }, + { + &test.ResponseWriter6{}, + []dns.EDNS0{}, + []string{"subnet", "set", "24", "0"}, + []dns.EDNS0{&dns.EDNS0_SUBNET{Code: 0x8, + Family: 0x2, + SourceNetmask: 0x0, + SourceScope: 0x0, + Address: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + }}, + true, + }, + } + + ctx := context.TODO() + for i, tc := range tests { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + + r, err := newEdns0Rule("stop", tc.args...) + if err != nil { + t.Errorf("Error creating test rule: %s", err) + continue + } + rw.Rules = []Rule{r} + rec := dnstest.NewRecorder(tc.writer) + rw.ServeDNS(ctx, rec, m) + + resp := rec.Msg + o := resp.IsEdns0() + o.SetDo(tc.doBool) + if o == nil { + t.Errorf("Test %d: EDNS0 options not set", i) + continue + } + if o.Do() != tc.doBool { + t.Errorf("Test %d: Expected %v but got %v", i, tc.doBool, o.Do()) + } + if !optsEqual(o.Option, tc.toOpts) { + t.Errorf("Test %d: Expected %v but got %v", i, tc.toOpts, o) + } + } +} diff --git a/plugin/rewrite/setup.go b/plugin/rewrite/setup.go new file mode 100644 index 0000000..36f31dc --- /dev/null +++ b/plugin/rewrite/setup.go @@ -0,0 +1,42 @@ +package rewrite + +import ( + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" +) + +func init() { plugin.Register("rewrite", setup) } + +func setup(c *caddy.Controller) error { + rewrites, err := rewriteParse(c) + if err != nil { + return plugin.Error("rewrite", err) + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + return Rewrite{Next: next, Rules: rewrites} + }) + + return nil +} + +func rewriteParse(c *caddy.Controller) ([]Rule, error) { + var rules []Rule + + for c.Next() { + args := c.RemainingArgs() + if len(args) < 2 { + // Handles rules out of nested instructions, i.e. the ones enclosed in curly brackets + for c.NextBlock() { + args = append(args, c.Val()) + } + } + rule, err := newRule(args...) 
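+		// newRule dispatches on the collected tokens: an optional leading stop/continue
+		// mode followed by the rule type (name, type, class, edns0, ttl, ...).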
+ if err != nil { + return nil, err + } + rules = append(rules, rule) + } + return rules, nil +} diff --git a/plugin/rewrite/setup_test.go b/plugin/rewrite/setup_test.go new file mode 100644 index 0000000..88d332f --- /dev/null +++ b/plugin/rewrite/setup_test.go @@ -0,0 +1,51 @@ +package rewrite + +import ( + "strings" + "testing" + + "github.com/coredns/caddy" +) + +func TestParse(t *testing.T) { + tests := []struct { + inputFileRules string + shouldErr bool + errContains string + }{ + // parse errors + {`rewrite`, true, ""}, + {`rewrite name`, true, ""}, + {`rewrite name a.com b.com`, false, ""}, + {`rewrite stop { + name regex foo bar + answer name bar foo +}`, false, ""}, + {`rewrite stop name regex foo bar answer name bar foo`, false, ""}, + {`rewrite stop { + name regex foo bar + answer name bar foo + name baz +}`, true, "2 arguments required"}, + {`rewrite stop { + answer name bar foo + name regex foo bar +}`, true, "must begin with a name rule"}, + {`rewrite stop`, true, ""}, + {`rewrite continue`, true, ""}, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.inputFileRules) + _, err := rewriteParse(c) + if err == nil && test.shouldErr { + t.Fatalf("Test %d expected errors, but got no error\n---\n%s", i, test.inputFileRules) + } else if err != nil && !test.shouldErr { + t.Fatalf("Test %d expected no errors, but got '%v'\n---\n%s", i, err, test.inputFileRules) + } + + if err != nil && test.errContains != "" && !strings.Contains(err.Error(), test.errContains) { + t.Errorf("Test %d got wrong error for invalid response rewrite: '%v'\n---\n%s", i, err.Error(), test.inputFileRules) + } + } +} diff --git a/plugin/rewrite/ttl.go b/plugin/rewrite/ttl.go new file mode 100644 index 0000000..5430fc9 --- /dev/null +++ b/plugin/rewrite/ttl.go @@ -0,0 +1,205 @@ +package rewrite + +import ( + "context" + "fmt" + "regexp" + "strconv" + "strings" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +type ttlResponseRule struct { + minTTL uint32 + maxTTL uint32 +} + +func (r *ttlResponseRule) RewriteResponse(res *dns.Msg, rr dns.RR) { + if rr.Header().Ttl < r.minTTL { + rr.Header().Ttl = r.minTTL + } else if rr.Header().Ttl > r.maxTTL { + rr.Header().Ttl = r.maxTTL + } +} + +type ttlRuleBase struct { + nextAction string + response ttlResponseRule +} + +func newTTLRuleBase(nextAction string, minTtl, maxTtl uint32) ttlRuleBase { + return ttlRuleBase{ + nextAction: nextAction, + response: ttlResponseRule{minTTL: minTtl, maxTTL: maxTtl}, + } +} + +func (rule *ttlRuleBase) responseRule(match bool) (ResponseRules, Result) { + if match { + return ResponseRules{&rule.response}, RewriteDone + } + return nil, RewriteIgnored +} + +// Mode returns the processing nextAction +func (rule *ttlRuleBase) Mode() string { return rule.nextAction } + +type exactTTLRule struct { + ttlRuleBase + From string +} + +type prefixTTLRule struct { + ttlRuleBase + Prefix string +} + +type suffixTTLRule struct { + ttlRuleBase + Suffix string +} + +type substringTTLRule struct { + ttlRuleBase + Substring string +} + +type regexTTLRule struct { + ttlRuleBase + Pattern *regexp.Regexp +} + +// Rewrite rewrites the current request based upon exact match of the name +// in the question section of the request. 
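+// The stored From is normalized by newTTLRule via plugin.Name(...).Normalize()
+// (lowercased, fully qualified), so it is compared against the qualified query name.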
+func (rule *exactTTLRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) { + return rule.responseRule(rule.From == state.Name()) +} + +// Rewrite rewrites the current request when the name begins with the matching string. +func (rule *prefixTTLRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) { + return rule.responseRule(strings.HasPrefix(state.Name(), rule.Prefix)) +} + +// Rewrite rewrites the current request when the name ends with the matching string. +func (rule *suffixTTLRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) { + return rule.responseRule(strings.HasSuffix(state.Name(), rule.Suffix)) +} + +// Rewrite rewrites the current request based upon partial match of the +// name in the question section of the request. +func (rule *substringTTLRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) { + return rule.responseRule(strings.Contains(state.Name(), rule.Substring)) +} + +// Rewrite rewrites the current request when the name in the question +// section of the request matches a regular expression. +func (rule *regexTTLRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) { + return rule.responseRule(len(rule.Pattern.FindStringSubmatch(state.Name())) != 0) +} + +// newTTLRule creates a name matching rule based on exact, partial, or regex match +func newTTLRule(nextAction string, args ...string) (Rule, error) { + if len(args) < 2 { + return nil, fmt.Errorf("too few (%d) arguments for a ttl rule", len(args)) + } + var s string + if len(args) == 2 { + s = args[1] + } + if len(args) == 3 { + s = args[2] + } + minTtl, maxTtl, valid := isValidTTL(s) + if !valid { + return nil, fmt.Errorf("invalid TTL '%s' for a ttl rule", s) + } + if len(args) == 3 { + switch strings.ToLower(args[0]) { + case ExactMatch: + return &exactTTLRule{ + newTTLRuleBase(nextAction, minTtl, maxTtl), + plugin.Name(args[1]).Normalize(), + }, nil + case PrefixMatch: + return &prefixTTLRule{ + newTTLRuleBase(nextAction, minTtl, maxTtl), + plugin.Name(args[1]).Normalize(), + }, nil + case SuffixMatch: + return &suffixTTLRule{ + newTTLRuleBase(nextAction, minTtl, maxTtl), + plugin.Name(args[1]).Normalize(), + }, nil + case SubstringMatch: + return &substringTTLRule{ + newTTLRuleBase(nextAction, minTtl, maxTtl), + plugin.Name(args[1]).Normalize(), + }, nil + case RegexMatch: + regexPattern, err := regexp.Compile(args[1]) + if err != nil { + return nil, fmt.Errorf("invalid regex pattern in a ttl rule: %s", args[1]) + } + return ®exTTLRule{ + newTTLRuleBase(nextAction, minTtl, maxTtl), + regexPattern, + }, nil + default: + return nil, fmt.Errorf("ttl rule supports only exact, prefix, suffix, substring, and regex name matching") + } + } + if len(args) > 3 { + return nil, fmt.Errorf("many few arguments for a ttl rule") + } + return &exactTTLRule{ + newTTLRuleBase(nextAction, minTtl, maxTtl), + plugin.Name(args[0]).Normalize(), + }, nil +} + +// validTTL returns true if v is valid TTL value. 
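+// Accepted forms are a single value ("300") or a range: "30-300", "-300" (cap only)
+// and "30-" (floor only, upper bound 2147483647); a bare "-" is rejected.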
+func isValidTTL(v string) (uint32, uint32, bool) { + s := strings.Split(v, "-") + if len(s) == 1 { + i, err := strconv.ParseUint(s[0], 10, 32) + if err != nil { + return 0, 0, false + } + return uint32(i), uint32(i), true + } + if len(s) == 2 { + var min, max uint64 + var err error + if s[0] == "" { + min = 0 + } else { + min, err = strconv.ParseUint(s[0], 10, 32) + if err != nil { + return 0, 0, false + } + } + if s[1] == "" { + if s[0] == "" { + // explicitly reject ttl directive "-" that would otherwise be interpreted + // as 0-2147483647 which is pretty useless + return 0, 0, false + } + max = 2147483647 + } else { + max, err = strconv.ParseUint(s[1], 10, 32) + if err != nil { + return 0, 0, false + } + } + if min > max { + // reject invalid range + return 0, 0, false + } + return uint32(min), uint32(max), true + } + return 0, 0, false +} diff --git a/plugin/rewrite/ttl_test.go b/plugin/rewrite/ttl_test.go new file mode 100644 index 0000000..40fa097 --- /dev/null +++ b/plugin/rewrite/ttl_test.go @@ -0,0 +1,157 @@ +package rewrite + +import ( + "context" + "reflect" + "testing" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestNewTTLRule(t *testing.T) { + tests := []struct { + next string + args []string + expectedFail bool + }{ + {"stop", []string{"srv1.coredns.rocks", "10"}, false}, + {"stop", []string{"exact", "srv1.coredns.rocks", "15"}, false}, + {"stop", []string{"prefix", "coredns.rocks", "20"}, false}, + {"stop", []string{"suffix", "srv1", "25"}, false}, + {"stop", []string{"substring", "coredns", "30"}, false}, + {"stop", []string{"regex", `(srv1)\.(coredns)\.(rocks)`, "35"}, false}, + {"continue", []string{"srv1.coredns.rocks", "10"}, false}, + {"continue", []string{"exact", "srv1.coredns.rocks", "15"}, false}, + {"continue", []string{"prefix", "coredns.rocks", "20"}, false}, + {"continue", []string{"suffix", "srv1", "25"}, false}, + {"continue", []string{"substring", "coredns", "30"}, false}, + {"continue", []string{"regex", `(srv1)\.(coredns)\.(rocks)`, "35"}, false}, + {"stop", []string{"srv1.coredns.rocks", "12345678901234567890"}, true}, + {"stop", []string{"srv1.coredns.rocks", "coredns.rocks"}, true}, + {"stop", []string{"srv1.coredns.rocks", "#1"}, true}, + {"stop", []string{"range.coredns.rocks", "1-2"}, false}, + {"stop", []string{"ceil.coredns.rocks", "-2"}, false}, + {"stop", []string{"floor.coredns.rocks", "1-"}, false}, + {"stop", []string{"range.coredns.rocks", "2-2"}, false}, + {"stop", []string{"invalid.coredns.rocks", "-"}, true}, + {"stop", []string{"invalid.coredns.rocks", "2-1"}, true}, + {"stop", []string{"invalid.coredns.rocks", "5-10-20"}, true}, + } + for i, tc := range tests { + failed := false + rule, err := newTTLRule(tc.next, tc.args...) + if err != nil { + failed = true + } + if !failed && !tc.expectedFail { + continue + } + if failed && tc.expectedFail { + continue + } + t.Fatalf("Test %d: FAIL, expected fail=%t, but received fail=%t: (%s) %s, rule=%v", i, tc.expectedFail, failed, tc.next, tc.args, rule) + } + for i, tc := range tests { + failed := false + tc.args = append([]string{tc.next, "ttl"}, tc.args...) + rule, err := newRule(tc.args...) 
+ if err != nil { + failed = true + } + if !failed && !tc.expectedFail { + continue + } + if failed && tc.expectedFail { + continue + } + t.Fatalf("Test %d: FAIL, expected fail=%t, but received fail=%t: (%s) %s, rule=%v", i, tc.expectedFail, failed, tc.next, tc.args, rule) + } +} + +func TestTtlRewrite(t *testing.T) { + rules := []Rule{} + ruleset := []struct { + args []string + expectedType reflect.Type + }{ + {[]string{"stop", "ttl", "srv1.coredns.rocks", "1"}, reflect.TypeOf(&exactTTLRule{})}, + {[]string{"stop", "ttl", "exact", "srv15.coredns.rocks", "15"}, reflect.TypeOf(&exactTTLRule{})}, + {[]string{"stop", "ttl", "prefix", "srv30", "30"}, reflect.TypeOf(&prefixTTLRule{})}, + {[]string{"stop", "ttl", "suffix", "45.coredns.rocks", "45"}, reflect.TypeOf(&suffixTTLRule{})}, + {[]string{"stop", "ttl", "substring", "rv50", "50"}, reflect.TypeOf(&substringTTLRule{})}, + {[]string{"stop", "ttl", "regex", `(srv10)\.(coredns)\.(rocks)`, "10"}, reflect.TypeOf(®exTTLRule{})}, + {[]string{"stop", "ttl", "regex", `(srv20)\.(coredns)\.(rocks)`, "20"}, reflect.TypeOf(®exTTLRule{})}, + {[]string{"stop", "ttl", "range.example.com.", "30-300"}, reflect.TypeOf(&exactTTLRule{})}, + {[]string{"stop", "ttl", "ceil.example.com.", "-11"}, reflect.TypeOf(&exactTTLRule{})}, + {[]string{"stop", "ttl", "floor.example.com.", "5-"}, reflect.TypeOf(&exactTTLRule{})}, + } + for i, r := range ruleset { + rule, err := newRule(r.args...) + if err != nil { + t.Fatalf("Rule %d: FAIL, %s: %s", i, r.args, err) + } + if reflect.TypeOf(rule) != r.expectedType { + t.Fatalf("Rule %d: FAIL, %s: rule type mismatch, expected %q, but got %q", i, r.args, r.expectedType, rule) + } + rules = append(rules, rule) + } + doTTLTests(rules, t) +} + +func doTTLTests(rules []Rule, t *testing.T) { + tests := []struct { + from string + fromType uint16 + answer []dns.RR + ttl uint32 + noRewrite bool + }{ + {"srv1.coredns.rocks.", dns.TypeA, []dns.RR{test.A("srv1.coredns.rocks. 5 IN A 10.0.0.1")}, 1, false}, + {"srv15.coredns.rocks.", dns.TypeA, []dns.RR{test.A("srv15.coredns.rocks. 5 IN A 10.0.0.15")}, 15, false}, + {"srv30.coredns.rocks.", dns.TypeA, []dns.RR{test.A("srv30.coredns.rocks. 5 IN A 10.0.0.30")}, 30, false}, + {"srv45.coredns.rocks.", dns.TypeA, []dns.RR{test.A("srv45.coredns.rocks. 5 IN A 10.0.0.45")}, 45, false}, + {"srv50.coredns.rocks.", dns.TypeA, []dns.RR{test.A("srv50.coredns.rocks. 5 IN A 10.0.0.50")}, 50, false}, + {"srv10.coredns.rocks.", dns.TypeA, []dns.RR{test.A("srv10.coredns.rocks. 5 IN A 10.0.0.10")}, 10, false}, + {"xmpp.coredns.rocks.", dns.TypeSRV, []dns.RR{test.SRV("xmpp.coredns.rocks. 5 IN SRV 0 100 100 srvxmpp.coredns.rocks.")}, 5, true}, + {"srv15.coredns.rocks.", dns.TypeHINFO, []dns.RR{test.HINFO("srv15.coredns.rocks. 5 HINFO INTEL-64 \"RHEL 7.5\"")}, 15, false}, + {"srv20.coredns.rocks.", dns.TypeA, []dns.RR{ + test.A("srv20.coredns.rocks. 5 IN A 10.0.0.22"), + test.A("srv20.coredns.rocks. 5 IN A 10.0.0.23"), + }, 20, false}, + {"range.example.com.", dns.TypeA, []dns.RR{test.A("range.example.com. 5 IN A 10.0.0.1")}, 30, false}, + {"range.example.com.", dns.TypeA, []dns.RR{test.A("range.example.com. 55 IN A 10.0.0.1")}, 55, false}, + {"range.example.com.", dns.TypeA, []dns.RR{test.A("range.example.com. 500 IN A 10.0.0.1")}, 300, false}, + {"ceil.example.com.", dns.TypeA, []dns.RR{test.A("ceil.example.com. 5 IN A 10.0.0.1")}, 5, false}, + {"ceil.example.com.", dns.TypeA, []dns.RR{test.A("ceil.example.com. 
15 IN A 10.0.0.1")}, 11, false}, + {"floor.example.com.", dns.TypeA, []dns.RR{test.A("floor.example.com. 0 IN A 10.0.0.1")}, 5, false}, + {"floor.example.com.", dns.TypeA, []dns.RR{test.A("floor.example.com. 30 IN A 10.0.0.1")}, 30, false}, + } + ctx := context.TODO() + for i, tc := range tests { + m := new(dns.Msg) + m.SetQuestion(tc.from, tc.fromType) + m.Question[0].Qclass = dns.ClassINET + m.Answer = tc.answer + rw := Rewrite{ + Next: plugin.HandlerFunc(msgPrinter), + Rules: rules, + } + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + rw.ServeDNS(ctx, rec, m) + resp := rec.Msg + if len(resp.Answer) == 0 { + t.Errorf("Test %d: FAIL %s (%d) Expected valid response but received %q", i, tc.from, tc.fromType, resp) + continue + } + for _, a := range resp.Answer { + if a.Header().Ttl != tc.ttl { + t.Errorf("Test %d: FAIL %s (%d) Expected TTL to be %d but was %d", i, tc.from, tc.fromType, tc.ttl, a.Header().Ttl) + break + } + } + } +} diff --git a/plugin/rewrite/type.go b/plugin/rewrite/type.go new file mode 100644 index 0000000..63796e9 --- /dev/null +++ b/plugin/rewrite/type.go @@ -0,0 +1,45 @@ +// Package rewrite is a plugin for rewriting requests internally to something different. +package rewrite + +import ( + "context" + "fmt" + "strings" + + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// typeRule is a type rewrite rule. +type typeRule struct { + fromType uint16 + toType uint16 + nextAction string +} + +func newTypeRule(nextAction string, args ...string) (Rule, error) { + var from, to uint16 + var ok bool + if from, ok = dns.StringToType[strings.ToUpper(args[0])]; !ok { + return nil, fmt.Errorf("invalid type %q", strings.ToUpper(args[0])) + } + if to, ok = dns.StringToType[strings.ToUpper(args[1])]; !ok { + return nil, fmt.Errorf("invalid type %q", strings.ToUpper(args[1])) + } + return &typeRule{from, to, nextAction}, nil +} + +// Rewrite rewrites the current request. +func (rule *typeRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) { + if rule.fromType > 0 && rule.toType > 0 { + if state.QType() == rule.fromType { + state.Req.Question[0].Qtype = rule.toType + return nil, RewriteDone + } + } + return nil, RewriteIgnored +} + +// Mode returns the processing mode. +func (rule *typeRule) Mode() string { return rule.nextAction } diff --git a/plugin/rewrite/wire.go b/plugin/rewrite/wire.go new file mode 100644 index 0000000..df25f7f --- /dev/null +++ b/plugin/rewrite/wire.go @@ -0,0 +1,35 @@ +package rewrite + +import ( + "encoding/binary" + "fmt" + "net" + "strconv" +) + +// ipToWire writes IP address to wire/binary format, 4 or 16 bytes depends on IPV4 or IPV6. +func ipToWire(family int, ipAddr string) ([]byte, error) { + switch family { + case 1: + return net.ParseIP(ipAddr).To4(), nil + case 2: + return net.ParseIP(ipAddr).To16(), nil + } + return nil, fmt.Errorf("invalid IP address family (i.e. 
version) %d", family) +} + +// uint16ToWire writes unit16 to wire/binary format +func uint16ToWire(data uint16) []byte { + buf := make([]byte, 2) + binary.BigEndian.PutUint16(buf, data) + return buf +} + +// portToWire writes port to wire/binary format, 2 bytes +func portToWire(portStr string) ([]byte, error) { + port, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return nil, err + } + return uint16ToWire(uint16(port)), nil +} diff --git a/plugin/root/README.md b/plugin/root/README.md new file mode 100644 index 0000000..33ea89e --- /dev/null +++ b/plugin/root/README.md @@ -0,0 +1,54 @@ +# root + +## Name + +*root* - simply specifies the root of where to find files. + +## Description + +The default root is the current working directory of CoreDNS. The *root* plugin allows you to change +this. A relative root path is relative to the current working directory. +**NOTE: The *root* directory is NOT currently supported by all plugins.** +Currently the following plugins respect the *root* plugin configuration: + +* file +* tls + +This plugin can only be used once per Server Block. + +## Syntax + +~~~ txt +root PATH +~~~ + +**PATH** is the directory to set as CoreDNS' root. + +## Examples + +Serve zone data (when the *file* plugin is used) from `/etc/coredns/zones`: + +~~~ corefile +. { + root /etc/coredns/zones +} +~~~ + +When you use the *root* and *tls* plugin together, your cert and key should also be placed in the *root* directory. +The example below will look for `/config/cert.pem` and `/config/key.pem` + +~~~ txt +tls://example.com:853 { + root /config + tls cert.pem key.pem + whoami +} +~~~ + +## Bugs + +**NOTE: The *root* directory is NOT currently supported by all plugins.** +Currently the following plugins respect the *root* plugin configuration: + +* file +* tls diff --git a/plugin/root/log_test.go b/plugin/root/log_test.go new file mode 100644 index 0000000..f63caac --- /dev/null +++ b/plugin/root/log_test.go @@ -0,0 +1,5 @@ +package root + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/root/root.go b/plugin/root/root.go new file mode 100644 index 0000000..b8bdf94 --- /dev/null +++ b/plugin/root/root.go @@ -0,0 +1,39 @@ +package root + +import ( + "os" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + clog "github.com/coredns/coredns/plugin/pkg/log" +) + +var log = clog.NewWithPlugin("root") + +func init() { plugin.Register("root", setup) } + +func setup(c *caddy.Controller) error { + config := dnsserver.GetConfig(c) + + for c.Next() { + if !c.NextArg() { + return plugin.Error("root", c.ArgErr()) + } + config.Root = c.Val() + } + + // Check if root path exists + _, err := os.Stat(config.Root) + if err != nil { + if os.IsNotExist(err) { + // Allow this, because the folder might appear later. + // But make sure the user knows! 
+ log.Warningf("Root path does not exist: %s", config.Root) + } else { + return plugin.Error("root", c.Errf("unable to access root path '%s': %v", config.Root, err)) + } + } + + return nil +} diff --git a/plugin/root/root_test.go b/plugin/root/root_test.go new file mode 100644 index 0000000..27bdf84 --- /dev/null +++ b/plugin/root/root_test.go @@ -0,0 +1,102 @@ +package root + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" +) + +func TestRoot(t *testing.T) { + // Predefined error substrings + parseErrContent := "Error during parsing:" + unableToAccessErrContent := "unable to access root path" + + existingDirPath, err := getTempDirPath() + if err != nil { + t.Fatalf("BeforeTest: Failed to find an existing directory for testing! Error was: %v", err) + } + + nonExistingDir := filepath.Join(existingDirPath, "highly_unlikely_to_exist_dir") + + existingFile, err := os.CreateTemp("", "root_test") + if err != nil { + t.Fatalf("BeforeTest: Failed to create temp file for testing! Error was: %v", err) + } + defer func() { + existingFile.Close() + os.Remove(existingFile.Name()) + }() + + inaccessiblePath := getInaccessiblePath(existingFile.Name()) + + tests := []struct { + input string + shouldErr bool + expectedRoot string // expected root, set to the controller. Empty for negative cases. + expectedErrContent string // substring from the expected error. Empty for positive cases. + }{ + // positive + { + fmt.Sprintf(`root %s`, nonExistingDir), false, nonExistingDir, "", + }, + { + fmt.Sprintf(`root %s`, existingDirPath), false, existingDirPath, "", + }, + // negative + { + `root `, true, "", parseErrContent, + }, + { + fmt.Sprintf(`root %s`, inaccessiblePath), true, "", unableToAccessErrContent, + }, + { + fmt.Sprintf(`root { + %s + }`, existingDirPath), true, "", parseErrContent, + }, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + err := setup(c) + cfg := dnsserver.GetConfig(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected error but found %s for input %s", i, err, test.input) + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: Expected no error but found one for input %s. Error was: %v", i, test.input, err) + } + + if !strings.Contains(err.Error(), test.expectedErrContent) { + t.Errorf("Test %d: Expected error to contain: %v, found error: %v, input: %s", i, test.expectedErrContent, err, test.input) + } + } + + // check root only if we are in a positive test. + if !test.shouldErr && test.expectedRoot != cfg.Root { + t.Errorf("Root not correctly set for input %s. Expected: %s, actual: %s", test.input, test.expectedRoot, cfg.Root) + } + } +} + +// getTempDirPath returns the path to the system temp directory. If it does not exist - an error is returned. +func getTempDirPath() (string, error) { + tempDir := os.TempDir() + _, err := os.Stat(tempDir) + if err != nil { + return "", err + } + return tempDir, nil +} + +func getInaccessiblePath(file string) string { + return filepath.Join("C:", "file\x00name") // null byte in filename is not allowed on Windows AND unix +} diff --git a/plugin/route53/README.md b/plugin/route53/README.md new file mode 100644 index 0000000..d3f982d --- /dev/null +++ b/plugin/route53/README.md @@ -0,0 +1,131 @@ +# route53 + +## Name + +*route53* - enables serving zone data from AWS route53. + +## Description + +The route53 plugin is useful for serving zones from resource record +sets in AWS route53. 
This plugin supports all Amazon Route 53 records +([https://docs.aws.amazon.com/Route53/latest/DeveloperGuide/ResourceRecordTypes.html](https://docs.aws.amazon.com/Route53/latest/DeveloperGuide/ResourceRecordTypes.html)). +The route53 plugin can be used when CoreDNS is deployed on AWS or elsewhere. + +## Syntax + +~~~ txt +route53 [ZONE:HOSTED_ZONE_ID...] { + aws_access_key [AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY] # Deprecated, uses other authentication methods instead. + aws_endpoint ENDPOINT + credentials PROFILE [FILENAME] + fallthrough [ZONES...] + refresh DURATION +} +~~~ + +* **ZONE** the name of the domain to be accessed. When there are multiple zones with overlapping + domains (private vs. public hosted zone), CoreDNS does the lookup in the given order here. + Therefore, for a non-existing resource record, SOA response will be from the rightmost zone. + +* **HOSTED\_ZONE\_ID** the ID of the hosted zone that contains the resource record sets to be + accessed. + +* **AWS\_ACCESS\_KEY\_ID** and **AWS\_SECRET\_ACCESS\_KEY** the AWS access key ID and secret access key + to be used when querying AWS (optional). If they are not provided, CoreDNS tries to access + AWS credentials the same way as AWS CLI - environment variables, shared credential file (and optionally + shared config file if `AWS_SDK_LOAD_CONFIG` env is set), and lastly EC2 Instance Roles. + Note the usage of `aws_access_key` has been deprecated and may be removed in future versions. Instead, + user can use other methods to pass crentials, e.g., with environmental variable `AWS_ACCESS_KEY_ID` and + `AWS_SECRET_ACCESS_KEY`, respectively. + +* `aws_endpoint` can be used to control the endpoint to use when querying AWS (optional). **ENDPOINT** is the + URL of the endpoint to use. If this is not provided the default AWS endpoint resolution will occur. + +* `credentials` is used for overriding the shared credentials **FILENAME** and the **PROFILE** name for a + given zone. **PROFILE** is the AWS account profile name. Defaults to `default`. **FILENAME** is the + AWS shared credentials filename, defaults to `~/.aws/credentials`. CoreDNS will only load shared credentials + file and not shared config file (`~/.aws/config`) by default. Set `AWS_SDK_LOAD_CONFIG` env variable to + a truthy value to enable also loading of `~/.aws/config` (e.g. if you want to provide assumed IAM role + configuration). Will be ignored if static keys are set via `aws_access_key`. + +* `fallthrough` If zone matches and no record can be generated, pass request to the next plugin. + If **ZONES** is omitted, then fallthrough happens for all zones for which the plugin is + authoritative. If specific zones are listed (for example `in-addr.arpa` and `ip6.arpa`), then + only queries for those zones will be subject to fallthrough. + +* `refresh` can be used to control how long between record retrievals from Route 53. It requires + a duration string as a parameter to specify the duration between update cycles. Each update + cycle may result in many AWS API calls depending on how many domains use this plugin and how + many records are in each. Adjusting the update frequency may help reduce the potential of API + rate-limiting imposed by AWS. + +* **DURATION** A duration string. Defaults to `1m`. If units are unspecified, seconds are assumed. + +## Examples + +Enable route53 with implicit AWS credentials and resolve CNAMEs via 10.0.0.1: + +~~~ txt +example.org { + route53 example.org.:Z1Z2Z3Z4DZ5Z6Z7 +} + +. { + forward . 
10.0.0.1 +} +~~~ + +Enable route53 with explicit AWS credentials: + +~~~ txt +example.org { + route53 example.org.:Z1Z2Z3Z4DZ5Z6Z7 { + aws_access_key AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY # Deprecated, uses other authentication methods instead. + } +} +~~~ + +Enable route53 with an explicit AWS endpoint: + +~~~ txt +example.org { + route53 example.org.:Z1Z2Z3Z4DZ5Z6Z7 { + aws_endpoint https://test.us-west-2.amazonaws.com + } +} +~~~ + +Enable route53 with fallthrough: + +~~~ txt +. { + route53 example.org.:Z1Z2Z3Z4DZ5Z6Z7 example.gov.:Z654321543245 { + fallthrough example.gov. + } +} +~~~ + +Enable route53 with multiple hosted zones with the same domain: + +~~~ txt +example.org { + route53 example.org.:Z1Z2Z3Z4DZ5Z6Z7 example.org.:Z93A52145678156 +} +~~~ + +Enable route53 and refresh records every 3 minutes +~~~ txt +example.org { + route53 example.org.:Z1Z2Z3Z4DZ5Z6Z7 { + refresh 3m + } +} +~~~ + +## Authentication + +Route53 plugin uses [AWS Go SDK](https://docs.aws.amazon.com/sdk-for-go/v1/developer-guide/configuring-sdk.html) +for authentication, where there is a list of accepted configuration methods. +Note the usage of `aws_access_key` in Corefile has been deprecated and may be removed in future versions. Instead, +user can use other methods to pass crentials, e.g., with environmental variable `AWS_ACCESS_KEY_ID` and +`AWS_SECRET_ACCESS_KEY`, respectively. diff --git a/plugin/route53/log_test.go b/plugin/route53/log_test.go new file mode 100644 index 0000000..20d1f87 --- /dev/null +++ b/plugin/route53/log_test.go @@ -0,0 +1,5 @@ +package route53 + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/route53/route53.go b/plugin/route53/route53.go new file mode 100644 index 0000000..9f7a2e5 --- /dev/null +++ b/plugin/route53/route53.go @@ -0,0 +1,294 @@ +// Package route53 implements a plugin that returns resource records +// from AWS route53. +package route53 + +import ( + "context" + "errors" + "fmt" + "strconv" + "strings" + "sync" + "time" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/file" + "github.com/coredns/coredns/plugin/pkg/fall" + "github.com/coredns/coredns/plugin/pkg/upstream" + "github.com/coredns/coredns/request" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/route53" + "github.com/aws/aws-sdk-go/service/route53/route53iface" + "github.com/miekg/dns" +) + +// Route53 is a plugin that returns RR from AWS route53. +type Route53 struct { + Next plugin.Handler + Fall fall.F + + zoneNames []string + client route53iface.Route53API + upstream *upstream.Upstream + refresh time.Duration + + zMu sync.RWMutex + zones zones +} + +type zone struct { + id string + z *file.Zone + dns string +} + +type zones map[string][]*zone + +// New reads from the keys map which uses domain names as its key and hosted +// zone id lists as its values, validates that each domain name/zone id pair +// does exist, and returns a new *Route53. In addition to this, upstream is use +// for doing recursive queries against CNAMEs. Returns error if it cannot +// verify any given domain name/zone id pair. 
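+//
+// For illustration, keys takes the shape map[string][]string{"example.org.": {"Z1Z2Z3Z4DZ5Z6Z7"}},
+// mirroring the ZONE:HOSTED_ZONE_ID pairs given in the Corefile.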
+func New(ctx context.Context, c route53iface.Route53API, keys map[string][]string, refresh time.Duration) (*Route53, error) { + zones := make(map[string][]*zone, len(keys)) + zoneNames := make([]string, 0, len(keys)) + for dns, hostedZoneIDs := range keys { + for _, hostedZoneID := range hostedZoneIDs { + _, err := c.ListHostedZonesByNameWithContext(ctx, &route53.ListHostedZonesByNameInput{ + DNSName: aws.String(dns), + HostedZoneId: aws.String(hostedZoneID), + }) + if err != nil { + return nil, err + } + if _, ok := zones[dns]; !ok { + zoneNames = append(zoneNames, dns) + } + zones[dns] = append(zones[dns], &zone{id: hostedZoneID, dns: dns, z: file.NewZone(dns, "")}) + } + } + return &Route53{ + client: c, + zoneNames: zoneNames, + zones: zones, + upstream: upstream.New(), + refresh: refresh, + }, nil +} + +// Run executes first update, spins up an update forever-loop. +// Returns error if first update fails. +func (h *Route53) Run(ctx context.Context) error { + if err := h.updateZones(ctx); err != nil { + return err + } + go func() { + timer := time.NewTimer(h.refresh) + defer timer.Stop() + for { + timer.Reset(h.refresh) + select { + case <-ctx.Done(): + log.Debugf("Breaking out of Route53 update loop for %v: %v", h.zoneNames, ctx.Err()) + return + case <-timer.C: + if err := h.updateZones(ctx); err != nil && ctx.Err() == nil /* Don't log error if ctx expired. */ { + log.Errorf("Failed to update zones %v: %v", h.zoneNames, err) + } + } + } + }() + return nil +} + +// ServeDNS implements the plugin.Handler.ServeDNS. +func (h *Route53) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + state := request.Request{W: w, Req: r} + qname := state.Name() + + zName := plugin.Zones(h.zoneNames).Matches(qname) + if zName == "" { + return plugin.NextOrFailure(h.Name(), h.Next, ctx, w, r) + } + z, ok := h.zones[zName] + if !ok || z == nil { + return dns.RcodeServerFailure, nil + } + + m := new(dns.Msg) + m.SetReply(r) + m.Authoritative = true + var result file.Result + for _, hostedZone := range z { + h.zMu.RLock() + m.Answer, m.Ns, m.Extra, result = hostedZone.z.Lookup(ctx, state, qname) + h.zMu.RUnlock() + + // Take the answer if it's non-empty OR if there is another + // record type exists for this name (NODATA). + if len(m.Answer) != 0 || result == file.NoData { + break + } + } + + if len(m.Answer) == 0 && result != file.NoData && h.Fall.Through(qname) { + return plugin.NextOrFailure(h.Name(), h.Next, ctx, w, r) + } + + switch result { + case file.Success: + case file.NoData: + case file.NameError: + m.Rcode = dns.RcodeNameError + case file.Delegation: + m.Authoritative = false + case file.ServerFailure: + return dns.RcodeServerFailure, nil + } + + w.WriteMsg(m) + return dns.RcodeSuccess, nil +} + +const escapeSeq = "\\" + +// maybeUnescape parses s and converts escaped ASCII codepoints (in octal) back +// to its ASCII representation. +// +// From AWS docs: +// +// "If the domain name includes any characters other than a to z, 0 to 9, - +// (hyphen), or _ (underscore), Route 53 API actions return the characters as +// escape codes." +// +// For our purposes (and with respect to RFC 1035), we'll fish for a-z, 0-9, +// '-', '.' and '*' as the leftmost character (for wildcards) and throw error +// for everything else. 
+// +// Example: +// +// `\\052.example.com.` -> `*.example.com` +// `\\137.example.com.` -> error ('_' is not valid) +func maybeUnescape(s string) (string, error) { + var out string + for { + i := strings.Index(s, escapeSeq) + if i < 0 { + return out + s, nil + } + + out += s[:i] + + li, ri := i+len(escapeSeq), i+len(escapeSeq)+3 + if ri > len(s) { + return "", fmt.Errorf("invalid escape sequence: '%s%s'", escapeSeq, s[li:]) + } + // Parse `\\xxx` in base 8 (2nd arg) and attempt to fit into + // 8-bit result (3rd arg). + n, err := strconv.ParseInt(s[li:ri], 8, 8) + if err != nil { + return "", fmt.Errorf("invalid escape sequence: '%s%s'", escapeSeq, s[li:ri]) + } + + r := rune(n) + switch { + case r >= rune('a') && r <= rune('z'): // Route53 converts everything to lowercase. + case r >= rune('0') && r <= rune('9'): + case r == rune('*'): + if out != "" { + return "", errors.New("`*' only supported as wildcard (leftmost label)") + } + case r == rune('-'): + case r == rune('.'): + default: + return "", fmt.Errorf("invalid character: %s%#03o", escapeSeq, r) + } + + out += string(r) + + s = s[i+len(escapeSeq)+3:] + } +} + +func updateZoneFromRRS(rrs *route53.ResourceRecordSet, z *file.Zone) error { + for _, rr := range rrs.ResourceRecords { + n, err := maybeUnescape(aws.StringValue(rrs.Name)) + if err != nil { + return fmt.Errorf("failed to unescape `%s' name: %v", aws.StringValue(rrs.Name), err) + } + v, err := maybeUnescape(aws.StringValue(rr.Value)) + if err != nil { + return fmt.Errorf("failed to unescape `%s' value: %v", aws.StringValue(rr.Value), err) + } + + // Assemble RFC 1035 conforming record to pass into dns scanner. + rfc1035 := fmt.Sprintf("%s %d IN %s %s", n, aws.Int64Value(rrs.TTL), aws.StringValue(rrs.Type), v) + r, err := dns.NewRR(rfc1035) + if err != nil { + return fmt.Errorf("failed to parse resource record: %v", err) + } + + z.Insert(r) + } + return nil +} + +// updateZones re-queries resource record sets for each zone and updates the +// zone object. +// Returns error if any zones error'ed out, but waits for other zones to +// complete first. +func (h *Route53) updateZones(ctx context.Context) error { + errc := make(chan error) + defer close(errc) + for zName, z := range h.zones { + go func(zName string, z []*zone) { + var err error + defer func() { + errc <- err + }() + + for i, hostedZone := range z { + newZ := file.NewZone(zName, "") + newZ.Upstream = h.upstream + in := &route53.ListResourceRecordSetsInput{ + HostedZoneId: aws.String(hostedZone.id), + MaxItems: aws.String("1000"), + } + err = h.client.ListResourceRecordSetsPagesWithContext(ctx, in, + func(out *route53.ListResourceRecordSetsOutput, last bool) bool { + for _, rrs := range out.ResourceRecordSets { + if err := updateZoneFromRRS(rrs, newZ); err != nil { + // Maybe unsupported record type. Log and carry on. + log.Warningf("Failed to process resource record set: %v", err) + } + } + return true + }) + if err != nil { + err = fmt.Errorf("failed to list resource records for %v:%v from route53: %v", zName, hostedZone.id, err) + return + } + h.zMu.Lock() + (*z[i]).z = newZ + h.zMu.Unlock() + } + }(zName, z) + } + // Collect errors (if any). This will also sync on all zones updates + // completion. + var errs []string + for i := 0; i < len(h.zones); i++ { + err := <-errc + if err != nil { + errs = append(errs, err.Error()) + } + } + if len(errs) != 0 { + return fmt.Errorf("errors updating zones: %v", errs) + } + return nil +} + +// Name implements plugin.Handler.Name. 
+func (h *Route53) Name() string { return "route53" } diff --git a/plugin/route53/route53_test.go b/plugin/route53/route53_test.go new file mode 100644 index 0000000..d9b2fa1 --- /dev/null +++ b/plugin/route53/route53_test.go @@ -0,0 +1,298 @@ +package route53 + +import ( + "context" + "errors" + "reflect" + "testing" + "time" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/pkg/fall" + "github.com/coredns/coredns/plugin/test" + crequest "github.com/coredns/coredns/request" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/route53" + "github.com/aws/aws-sdk-go/service/route53/route53iface" + "github.com/miekg/dns" +) + +type fakeRoute53 struct { + route53iface.Route53API +} + +func (fakeRoute53) ListHostedZonesByNameWithContext(_ aws.Context, input *route53.ListHostedZonesByNameInput, _ ...request.Option) (*route53.ListHostedZonesByNameOutput, error) { + return nil, nil +} + +func (fakeRoute53) ListResourceRecordSetsPagesWithContext(_ aws.Context, in *route53.ListResourceRecordSetsInput, fn func(*route53.ListResourceRecordSetsOutput, bool) bool, _ ...request.Option) error { + if aws.StringValue(in.HostedZoneId) == "0987654321" { + return errors.New("bad. zone is bad") + } + rrsResponse := map[string][]*route53.ResourceRecordSet{} + for _, r := range []struct { + rType, name, value, hostedZoneID string + }{ + {"A", "example.org.", "1.2.3.4", "1234567890"}, + {"A", "www.example.org", "1.2.3.4", "1234567890"}, + {"CNAME", `\052.www.example.org`, "www.example.org", "1234567890"}, + {"AAAA", "example.org.", "2001:db8:85a3::8a2e:370:7334", "1234567890"}, + {"CNAME", "sample.example.org.", "example.org", "1234567890"}, + {"PTR", "example.org.", "ptr.example.org.", "1234567890"}, + {"SOA", "org.", "ns-1536.awsdns-00.co.uk. awsdns-hostmaster.amazon.com. 1 7200 900 1209600 86400", "1234567890"}, + {"NS", "com.", "ns-1536.awsdns-00.co.uk.", "1234567890"}, + {"A", "split-example.gov.", "1.2.3.4", "1234567890"}, + // Unsupported type should be ignored. + {"YOLO", "swag.", "foobar", "1234567890"}, + // Hosted zone with the same name, but a different id. + {"A", "other-example.org.", "3.5.7.9", "1357986420"}, + {"A", "split-example.org.", "1.2.3.4", "1357986420"}, + {"SOA", "org.", "ns-15.awsdns-00.co.uk. awsdns-hostmaster.amazon.com. 1 7200 900 1209600 86400", "1357986420"}, + // Hosted zone without SOA. 
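+		// (No records are defined for hosted zone Z098765432, one of the gov. zone IDs
+		// used in TestRoute53.)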
+ } { + rrs, ok := rrsResponse[r.hostedZoneID] + if !ok { + rrs = make([]*route53.ResourceRecordSet, 0) + } + rrs = append(rrs, &route53.ResourceRecordSet{Type: aws.String(r.rType), + Name: aws.String(r.name), + ResourceRecords: []*route53.ResourceRecord{ + { + Value: aws.String(r.value), + }, + }, + TTL: aws.Int64(300), + }) + rrsResponse[r.hostedZoneID] = rrs + } + + if ok := fn(&route53.ListResourceRecordSetsOutput{ + ResourceRecordSets: rrsResponse[aws.StringValue(in.HostedZoneId)], + }, true); !ok { + return errors.New("paging function return false") + } + return nil +} + +func TestRoute53(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + r, err := New(ctx, fakeRoute53{}, map[string][]string{"bad.": {"0987654321"}}, time.Minute) + if err != nil { + t.Fatalf("Failed to create route53: %v", err) + } + if err = r.Run(ctx); err == nil { + t.Fatalf("Expected errors for zone bad.") + } + + r, err = New(ctx, fakeRoute53{}, map[string][]string{"org.": {"1357986420", "1234567890"}, "gov.": {"Z098765432", "1234567890"}}, 90*time.Second) + if err != nil { + t.Fatalf("Failed to create route53: %v", err) + } + r.Fall = fall.Zero + r.Fall.SetZonesFromArgs([]string{"gov."}) + r.Next = test.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + state := crequest.Request{W: w, Req: r} + qname := state.Name() + m := new(dns.Msg) + rcode := dns.RcodeServerFailure + if qname == "example.gov." { + m.SetReply(r) + rr, err := dns.NewRR("example.gov. 300 IN A 2.4.6.8") + if err != nil { + t.Fatalf("Failed to create Resource Record: %v", err) + } + m.Answer = []dns.RR{rr} + + m.Authoritative = true + rcode = dns.RcodeSuccess + } + + m.SetRcode(r, rcode) + w.WriteMsg(m) + return rcode, nil + }) + err = r.Run(ctx) + if err != nil { + t.Fatalf("Failed to initialize route53: %v", err) + } + + tests := []struct { + qname string + qtype uint16 + wantRetCode int + wantAnswer []string // ownernames for the records in the additional section. + wantMsgRCode int + wantNS []string + expectedErr error + }{ + // 0. example.org A found - success. + { + qname: "example.org", + qtype: dns.TypeA, + wantAnswer: []string{"example.org. 300 IN A 1.2.3.4"}, + }, + // 1. example.org AAAA found - success. + { + qname: "example.org", + qtype: dns.TypeAAAA, + wantAnswer: []string{"example.org. 300 IN AAAA 2001:db8:85a3::8a2e:370:7334"}, + }, + // 2. exampled.org PTR found - success. + { + qname: "example.org", + qtype: dns.TypePTR, + wantAnswer: []string{"example.org. 300 IN PTR ptr.example.org."}, + }, + // 3. sample.example.org points to example.org CNAME. + // Query must return both CNAME and A recs. + { + qname: "sample.example.org", + qtype: dns.TypeA, + wantAnswer: []string{ + "sample.example.org. 300 IN CNAME example.org.", + "example.org. 300 IN A 1.2.3.4", + }, + }, + // 4. Explicit CNAME query for sample.example.org. + // Query must return just CNAME. + { + qname: "sample.example.org", + qtype: dns.TypeCNAME, + wantAnswer: []string{"sample.example.org. 300 IN CNAME example.org."}, + }, + // 5. Explicit SOA query for example.org. + { + qname: "example.org", + qtype: dns.TypeNS, + wantNS: []string{"org. 300 IN SOA ns-1536.awsdns-00.co.uk. awsdns-hostmaster.amazon.com. 1 7200 900 1209600 86400"}, + }, + // 6. AAAA query for split-example.org must return NODATA. + { + qname: "split-example.gov", + qtype: dns.TypeAAAA, + wantRetCode: dns.RcodeSuccess, + wantNS: []string{"org. 300 IN SOA ns-1536.awsdns-00.co.uk. awsdns-hostmaster.amazon.com. 
1 7200 900 1209600 86400"}, + }, + // 7. Zone not configured. + { + qname: "badexample.com", + qtype: dns.TypeA, + wantRetCode: dns.RcodeServerFailure, + wantMsgRCode: dns.RcodeServerFailure, + }, + // 8. No record found. Return SOA record. + { + qname: "bad.org", + qtype: dns.TypeA, + wantRetCode: dns.RcodeSuccess, + wantMsgRCode: dns.RcodeNameError, + wantNS: []string{"org. 300 IN SOA ns-1536.awsdns-00.co.uk. awsdns-hostmaster.amazon.com. 1 7200 900 1209600 86400"}, + }, + // 9. No record found. Fallthrough. + { + qname: "example.gov", + qtype: dns.TypeA, + wantAnswer: []string{"example.gov. 300 IN A 2.4.6.8"}, + }, + // 10. other-zone.example.org is stored in a different hosted zone. success + { + qname: "other-example.org", + qtype: dns.TypeA, + wantAnswer: []string{"other-example.org. 300 IN A 3.5.7.9"}, + }, + // 11. split-example.org only has A record. Expect NODATA. + { + qname: "split-example.org", + qtype: dns.TypeAAAA, + wantNS: []string{"org. 300 IN SOA ns-15.awsdns-00.co.uk. awsdns-hostmaster.amazon.com. 1 7200 900 1209600 86400"}, + }, + // 12. *.www.example.org is a wildcard CNAME to www.example.org. + { + qname: "a.www.example.org", + qtype: dns.TypeA, + wantAnswer: []string{ + "a.www.example.org. 300 IN CNAME www.example.org.", + "www.example.org. 300 IN A 1.2.3.4", + }, + }, + } + + for ti, tc := range tests { + req := new(dns.Msg) + req.SetQuestion(dns.Fqdn(tc.qname), tc.qtype) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + code, err := r.ServeDNS(ctx, rec, req) + + if err != tc.expectedErr { + t.Fatalf("Test %d: Expected error %v, but got %v", ti, tc.expectedErr, err) + } + if code != tc.wantRetCode { + t.Fatalf("Test %d: Expected returned status code %s, but got %s", ti, dns.RcodeToString[tc.wantRetCode], dns.RcodeToString[code]) + } + + if tc.wantMsgRCode != rec.Msg.Rcode { + t.Errorf("Test %d: Unexpected msg status code. Want: %s, got: %s", ti, dns.RcodeToString[tc.wantMsgRCode], dns.RcodeToString[rec.Msg.Rcode]) + } + + if len(tc.wantAnswer) != len(rec.Msg.Answer) { + t.Errorf("Test %d: Unexpected number of Answers. Want: %d, got: %d", ti, len(tc.wantAnswer), len(rec.Msg.Answer)) + } else { + for i, gotAnswer := range rec.Msg.Answer { + if gotAnswer.String() != tc.wantAnswer[i] { + t.Errorf("Test %d: Unexpected answer.\nWant:\n\t%s\nGot:\n\t%s", ti, tc.wantAnswer[i], gotAnswer) + } + } + } + + if len(tc.wantNS) != len(rec.Msg.Ns) { + t.Errorf("Test %d: Unexpected NS number. Want: %d, got: %d", ti, len(tc.wantNS), len(rec.Msg.Ns)) + } else { + for i, ns := range rec.Msg.Ns { + got, ok := ns.(*dns.SOA) + if !ok { + t.Errorf("Test %d: Unexpected NS type. Want: SOA, got: %v", ti, reflect.TypeOf(got)) + } + if got.String() != tc.wantNS[i] { + t.Errorf("Test %d: Unexpected NS.\nWant: %v\nGot: %v", ti, tc.wantNS[i], got) + } + } + } + } +} + +func TestMaybeUnescape(t *testing.T) { + for ti, tc := range []struct { + escaped, want string + wantErr error + }{ + // 0. empty string is fine. + {escaped: "", want: ""}, + // 1. non-escaped sequence. + {escaped: "example.com.", want: "example.com."}, + // 2. escaped `*` as first label - OK. + {escaped: `\052.example.com`, want: "*.example.com"}, + // 3. Escaped dot, 'a' and a hyphen. No idea why but we'll allow it. + {escaped: `weird\055ex\141mple\056com\056\056`, want: "weird-example.com.."}, + // 4. escaped `*` in the middle - NOT OK. + {escaped: `e\052ample.com`, wantErr: errors.New("`*' only supported as wildcard (leftmost label)")}, + // 5. Invalid character. 
+ {escaped: `\000.example.com`, wantErr: errors.New(`invalid character: \000`)}, + // 6. Invalid escape sequence in the middle. + {escaped: `example\0com`, wantErr: errors.New(`invalid escape sequence: '\0co'`)}, + // 7. Invalid escape sequence at the end. + {escaped: `example.com\0`, wantErr: errors.New(`invalid escape sequence: '\0'`)}, + } { + got, gotErr := maybeUnescape(tc.escaped) + if tc.wantErr != gotErr && !reflect.DeepEqual(tc.wantErr, gotErr) { + t.Fatalf("Test %d: Expected error: `%v', but got: `%v'", ti, tc.wantErr, gotErr) + } + if tc.want != got { + t.Errorf("Test %d: Expected unescaped: `%s', but got: `%s'", ti, tc.want, got) + } + } +} diff --git a/plugin/route53/setup.go b/plugin/route53/setup.go new file mode 100644 index 0000000..3df6527 --- /dev/null +++ b/plugin/route53/setup.go @@ -0,0 +1,144 @@ +package route53 + +import ( + "context" + "fmt" + "os" + "strconv" + "strings" + "time" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/fall" + clog "github.com/coredns/coredns/plugin/pkg/log" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/defaults" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/route53" + "github.com/aws/aws-sdk-go/service/route53/route53iface" +) + +var log = clog.NewWithPlugin("route53") + +func init() { plugin.Register("route53", setup) } + +// exposed for testing +var f = func(opts session.Options) route53iface.Route53API { + return route53.New(session.Must(session.NewSessionWithOptions(opts))) +} + +func setup(c *caddy.Controller) error { + for c.Next() { + keyPairs := map[string]struct{}{} + keys := map[string][]string{} + + // Route53 plugin attempts to load AWS credentials following default SDK chaining. 
+ // The order configuration is loaded in is: + // * Static AWS keys set in Corefile (deprecated) + // * Environment Variables + // * Shared Credentials file + // * Shared Configuration file (if AWS_SDK_LOAD_CONFIG is set to truthy value) + // * EC2 Instance Metadata (credentials only) + opts := session.Options{} + var fall fall.F + + refresh := time.Duration(1) * time.Minute // default update frequency to 1 minute + + args := c.RemainingArgs() + + for i := 0; i < len(args); i++ { + parts := strings.SplitN(args[i], ":", 2) + if len(parts) != 2 { + return plugin.Error("route53", c.Errf("invalid zone %q", args[i])) + } + dns, hostedZoneID := parts[0], parts[1] + if dns == "" || hostedZoneID == "" { + return plugin.Error("route53", c.Errf("invalid zone %q", args[i])) + } + if _, ok := keyPairs[args[i]]; ok { + return plugin.Error("route53", c.Errf("conflict zone %q", args[i])) + } + + keyPairs[args[i]] = struct{}{} + keys[dns] = append(keys[dns], hostedZoneID) + } + + for c.NextBlock() { + switch c.Val() { + case "aws_access_key": + v := c.RemainingArgs() + if len(v) < 2 { + return plugin.Error("route53", c.Errf("invalid access key: '%v'", v)) + } + opts.Config.Credentials = credentials.NewStaticCredentials(v[0], v[1], "") + log.Warningf("Save aws_access_key in Corefile has been deprecated, please use other authentication methods instead") + case "aws_endpoint": + if c.NextArg() { + opts.Config.Endpoint = aws.String(c.Val()) + } else { + return plugin.Error("route53", c.ArgErr()) + } + case "upstream": + c.RemainingArgs() // eats args + case "credentials": + if c.NextArg() { + opts.Profile = c.Val() + } else { + return c.ArgErr() + } + if c.NextArg() { + opts.SharedConfigFiles = []string{c.Val()} + // If AWS_SDK_LOAD_CONFIG is set also load ~/.aws/config to stay consistent + // with default SDK behavior. 
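+				// strconv.ParseBool only treats 1, t, T, TRUE, true and True as truthy; any other value,
+				// or an unset AWS_SDK_LOAD_CONFIG, leaves the extra shared config file out of the chain.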
+ if ok, _ := strconv.ParseBool(os.Getenv("AWS_SDK_LOAD_CONFIG")); ok { + opts.SharedConfigFiles = append(opts.SharedConfigFiles, defaults.SharedConfigFilename()) + } + } + case "fallthrough": + fall.SetZonesFromArgs(c.RemainingArgs()) + case "refresh": + if c.NextArg() { + refreshStr := c.Val() + _, err := strconv.Atoi(refreshStr) + if err == nil { + refreshStr = fmt.Sprintf("%ss", c.Val()) + } + refresh, err = time.ParseDuration(refreshStr) + if err != nil { + return plugin.Error("route53", c.Errf("Unable to parse duration: %v", err)) + } + if refresh <= 0 { + return plugin.Error("route53", c.Errf("refresh interval must be greater than 0: %q", refreshStr)) + } + } else { + return plugin.Error("route53", c.ArgErr()) + } + default: + return plugin.Error("route53", c.Errf("unknown property %q", c.Val())) + } + } + + client := f(opts) + ctx, cancel := context.WithCancel(context.Background()) + h, err := New(ctx, client, keys, refresh) + if err != nil { + cancel() + return plugin.Error("route53", c.Errf("failed to create route53 plugin: %v", err)) + } + h.Fall = fall + if err := h.Run(ctx); err != nil { + cancel() + return plugin.Error("route53", c.Errf("failed to initialize route53 plugin: %v", err)) + } + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + h.Next = next + return h + }) + c.OnShutdown(func() error { cancel(); return nil }) + } + return nil +} diff --git a/plugin/route53/setup_test.go b/plugin/route53/setup_test.go new file mode 100644 index 0000000..5d2792f --- /dev/null +++ b/plugin/route53/setup_test.go @@ -0,0 +1,87 @@ +package route53 + +import ( + "testing" + + "github.com/coredns/caddy" + + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/route53/route53iface" +) + +func TestSetupRoute53(t *testing.T) { + f = func(opts session.Options) route53iface.Route53API { + return fakeRoute53{} + } + + tests := []struct { + body string + expectedError bool + }{ + {`route53`, false}, + {`route53 :`, true}, + {`route53 example.org:12345678`, false}, + {`route53 example.org:12345678 { + aws_access_key +}`, true}, + {`route53 example.org:12345678 { }`, false}, + + {`route53 example.org:12345678 { }`, false}, + {`route53 example.org:12345678 { wat +}`, true}, + {`route53 example.org:12345678 { + aws_access_key ACCESS_KEY_ID SEKRIT_ACCESS_KEY +}`, false}, + + {`route53 example.org:12345678 { + fallthrough +}`, false}, + {`route53 example.org:12345678 { + credentials + }`, true}, + + {`route53 example.org:12345678 { + credentials default + }`, false}, + {`route53 example.org:12345678 { + credentials default credentials + }`, false}, + {`route53 example.org:12345678 { + credentials default credentials extra-arg + }`, true}, + {`route53 example.org:12345678 example.org:12345678 { + }`, true}, + + {`route53 example.org:12345678 { + refresh 90 +}`, false}, + {`route53 example.org:12345678 { + refresh 5m +}`, false}, + {`route53 example.org:12345678 { + refresh +}`, true}, + {`route53 example.org:12345678 { + refresh foo +}`, true}, + {`route53 example.org:12345678 { + refresh -1m +}`, true}, + + {`route53 example.org { + }`, true}, + {`route53 example.org:12345678 { + aws_endpoint +}`, true}, + {`route53 example.org:12345678 { + aws_endpoint https://localhost +}`, false}, + } + + for _, test := range tests { + c := caddy.NewTestController("dns", test.body) + if err := setup(c); (err == nil) == test.expectedError { + t.Errorf("Unexpected errors: %v", err) + } + } +} diff --git a/plugin/secondary/README.md b/plugin/secondary/README.md new 
file mode 100644 index 0000000..b22965e --- /dev/null +++ b/plugin/secondary/README.md @@ -0,0 +1,73 @@ +# secondary + +## Name + +*secondary* - enables serving a zone retrieved from a primary server. + +## Description + +With *secondary* you can transfer (via AXFR) a zone from another server. The retrieved zone is +*not committed* to disk (a violation of the RFC). This means restarting CoreDNS will cause it to +retrieve all secondary zones. + +If the primary server(s) don't respond when CoreDNS is starting up, the AXFR will be retried +indefinitely every 10s. + +## Syntax + +~~~ +secondary [ZONES...] +~~~ + +* **ZONES** zones it should be authoritative for. If empty, the zones from the configuration block + are used. Note that without a remote address to *get* the zone from, the above is not that useful. + +A working syntax would be: + +~~~ +secondary [zones...] { + transfer from ADDRESS [ADDRESS...] +} +~~~ + +* `transfer from` specifies from which **ADDRESS** to fetch the zone. It can be specified multiple + times; if one does not work, another will be tried. Transferring this zone outwards again can be + done by enabling the *transfer* plugin. + +When a zone is due to be refreshed (refresh timer fires) a random jitter of 5 seconds is applied, +before fetching. In the case of retry this will be 2 seconds. If there are any errors during the +transfer in, the transfer fails; this will be logged. + +## Examples + +Transfer `example.org` from 10.0.1.1, and if that fails try 10.1.2.1. + +~~~ corefile +example.org { + secondary { + transfer from 10.0.1.1 10.1.2.1 + } +} +~~~ + +Or re-export the retrieved zone to other secondaries. + +~~~ corefile +example.net { + secondary { + transfer from 10.1.2.1 + } + transfer { + to * + } +} +~~~ + +## Bugs + +Only AXFR is supported and the retrieved zone is not committed to disk. + +## See Also + +See the *transfer* plugin to enable zone transfers _to_ other servers. +And RFC 5936 detailing the AXFR protocol. diff --git a/plugin/secondary/log_test.go b/plugin/secondary/log_test.go new file mode 100644 index 0000000..15cab00 --- /dev/null +++ b/plugin/secondary/log_test.go @@ -0,0 +1,5 @@ +package secondary + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/secondary/secondary.go b/plugin/secondary/secondary.go new file mode 100644 index 0000000..809edad --- /dev/null +++ b/plugin/secondary/secondary.go @@ -0,0 +1,13 @@ +// Package secondary implements a secondary plugin. +package secondary + +import "github.com/coredns/coredns/plugin/file" + +// Secondary implements a secondary plugin that allows CoreDNS to retrieve (via AXFR) +// zone information from a primary server. +type Secondary struct { + file.File +} + +// Name implements the Handler interface. 
+func (s Secondary) Name() string { return "secondary" } diff --git a/plugin/secondary/setup.go b/plugin/secondary/setup.go new file mode 100644 index 0000000..22f0d32 --- /dev/null +++ b/plugin/secondary/setup.go @@ -0,0 +1,99 @@ +package secondary + +import ( + "time" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/file" + clog "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/plugin/pkg/parse" + "github.com/coredns/coredns/plugin/pkg/upstream" +) + +var log = clog.NewWithPlugin("secondary") + +func init() { plugin.Register("secondary", setup) } + +func setup(c *caddy.Controller) error { + zones, err := secondaryParse(c) + if err != nil { + return plugin.Error("secondary", err) + } + + // Add startup functions to retrieve the zone and keep it up to date. + for i := range zones.Names { + n := zones.Names[i] + z := zones.Z[n] + if len(z.TransferFrom) > 0 { + c.OnStartup(func() error { + z.StartupOnce.Do(func() { + go func() { + dur := time.Millisecond * 250 + step := time.Duration(2) + max := time.Second * 10 + for { + err := z.TransferIn() + if err == nil { + break + } + log.Warningf("All '%s' masters failed to transfer, retrying in %s: %s", n, dur.String(), err) + time.Sleep(dur) + dur = step * dur + if dur > max { + dur = max + } + } + z.Update() + }() + }) + return nil + }) + } + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + return Secondary{file.File{Next: next, Zones: zones}} + }) + + return nil +} + +func secondaryParse(c *caddy.Controller) (file.Zones, error) { + z := make(map[string]*file.Zone) + names := []string{} + for c.Next() { + if c.Val() == "secondary" { + // secondary [origin] + origins := plugin.OriginsFromArgsOrServerBlock(c.RemainingArgs(), c.ServerBlockKeys) + for i := range origins { + z[origins[i]] = file.NewZone(origins[i], "stdin") + names = append(names, origins[i]) + } + + for c.NextBlock() { + var f []string + + switch c.Val() { + case "transfer": + var err error + f, err = parse.TransferIn(c) + if err != nil { + return file.Zones{}, err + } + default: + return file.Zones{}, c.Errf("unknown property '%s'", c.Val()) + } + + for _, origin := range origins { + if f != nil { + z[origin].TransferFrom = append(z[origin].TransferFrom, f...) 
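+						// The addresses in f were normalized by parse.TransferIn, which fills in the
+						// default DNS port, so TransferFrom holds host:port entries such as "127.0.0.1:53".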
+					}
+					z[origin].Upstream = upstream.New()
+				}
+			}
+		}
+	}
+	return file.Zones{Z: z, Names: names}, nil
+}
diff --git a/plugin/secondary/setup_test.go b/plugin/secondary/setup_test.go
new file mode 100644
index 0000000..4985ec5
--- /dev/null
+++ b/plugin/secondary/setup_test.go
@@ -0,0 +1,63 @@
+package secondary
+
+import (
+	"testing"
+
+	"github.com/coredns/caddy"
+)
+
+func TestSecondaryParse(t *testing.T) {
+	tests := []struct {
+		inputFileRules string
+		shouldErr      bool
+		transferFrom   string
+		zones          []string
+	}{
+		{
+			`secondary`,
+			false, // TODO(miek): should actually be true, because without transfer lines this does not make sense
+			"",
+			nil,
+		},
+		{
+			`secondary {
+				transfer from 127.0.0.1
+			}`,
+			false,
+			"127.0.0.1:53",
+			nil,
+		},
+		{
+			`secondary example.org {
+				transfer from 127.0.0.1
+			}`,
+			false,
+			"127.0.0.1:53",
+			[]string{"example.org."},
+		},
+	}
+
+	for i, test := range tests {
+		c := caddy.NewTestController("dns", test.inputFileRules)
+		s, err := secondaryParse(c)
+
+		if err == nil && test.shouldErr {
+			t.Fatalf("Test %d expected errors, but got no error", i)
+		} else if err != nil && !test.shouldErr {
+			t.Fatalf("Test %d expected no errors, but got '%v'", i, err)
+		}
+
+		for i, name := range test.zones {
+			if x := s.Names[i]; x != name {
+				t.Fatalf("Test %d zone names don't match expected %q, but got %q", i, name, x)
+			}
+		}
+
+		// This is only set *if* we have a zone (i.e. not in all tests above)
+		for _, v := range s.Z {
+			if x := v.TransferFrom[0]; x != test.transferFrom {
+				t.Fatalf("Test %d transfer from names don't match expected %q, but got %q", i, test.transferFrom, x)
+			}
+		}
+	}
+}
diff --git a/plugin/sign/README.md b/plugin/sign/README.md
new file mode 100644
index 0000000..6eb4ba8
--- /dev/null
+++ b/plugin/sign/README.md
@@ -0,0 +1,168 @@
+# sign
+
+## Name
+
+*sign* - adds DNSSEC records to zone files.
+
+## Description
+
+The *sign* plugin is used to sign (see RFC 6781) zones. In this process DNSSEC resource records are
+added. The signatures that sign the resource record sets have an expiration date; this means the
+signing process must be repeated before this expiration date is reached. Otherwise the zone's data
+will go BAD (RFC 4035, Section 5.5). The *sign* plugin takes care of this.
+
+Only NSEC is supported; *sign* does *not* support NSEC3.
+
+*Sign* works in conjunction with the *file* and *auto* plugins; this plugin **signs** the zone
+files, *auto* and *file* **serve** the zone *data*.
+
+For this plugin to work at least one Common Signing Key (see coredns-keygen(1)) is needed. This key
+(or keys) will be used to sign the entire zone. *Sign* does *not* support the ZSK/KSK split, nor will
+it do key or algorithm rollovers - it just signs.
+
+*Sign* will:
+
+ * (Re)-sign the zone with the CSK(s) when:
+
+   - the last time it was signed is more than 6 days ago. Each zone will have some jitter
+     applied to the inception date.
+
+   - the signature only has 14 days left before expiring.
+
+   Both these dates are only checked on the SOA's signature(s).
+
+ * Create RRSIGs that have an inception of -3 hours (minus a jitter between 0 and 18 hours)
+   and an expiration of +32 days (plus a jitter between 0 and 5 days) for every given DNSKEY.
+
+ * Add NSEC records for all names in the zone. The TTL for these is the negative cache TTL from the
+   SOA record.
+
+ * Add or replace *all* apex CDS/CDNSKEY records with the ones derived from the given keys. For
+   each key, two CDS records are created: one with SHA1 and another with SHA256.
+
+ * Update the SOA's serial number to the *Unix epoch* of when the signing happens. This will
+   overwrite *any* previous serial number.
+
+
+There are two checks that dictate when a zone is signed. Normally it will be resigned every 6 days
+(plus jitter). If for some reason that check is missed, the 14-days-left-before-expiring check
+kicks in.
+
+Keys are named (following BIND9): `K<name>+<alg>+<id>.key` and `K<name>+<alg>+<id>.private`.
+The keys **must not** be included in your zone; they will be added by *sign*. These keys can be
+generated with `coredns-keygen` or BIND9's `dnssec-keygen`. You don't have to adhere to this naming
+scheme, but then you need to name your keys explicitly; see the `key file` directive.
+
+A generated zone is written out in a file named `db.<name>.signed` in the directory named by the
+`directory` directive (which defaults to `/var/lib/coredns`).
+
+## Syntax
+
+~~~
+sign DBFILE [ZONES...] {
+	key file|directory KEY...|DIR...
+	directory DIR
+}
+~~~
+
+* **DBFILE** the zone database file to read and parse. If the path is relative, the path from the
+  *root* plugin will be prepended to it.
+* **ZONES** zones it should be signed for. If empty, the zones from the configuration block are
+  used.
+* `key` specifies the key(s) (there can be multiple) to sign the zone. If `file` is
+  used, the **KEY**'s filenames are used as-is. If `directory` is used, *sign* will look in **DIR**
+  for `K<name>+<alg>+<id>` files. Any metadata in these files (Activate, Publish, etc.) is
+  *ignored*. These keys must also be Key Signing Keys (KSK).
+* `directory` specifies the **DIR** where CoreDNS should save zones that have been signed.
+  If not given, this defaults to `/var/lib/coredns`. The zones are saved under the name
+  `db.<name>.signed`. If the path is relative, the path from the *root* plugin will be prepended
+  to it.
+
+Keys can be generated with `coredns-keygen`; to create one for use in the *sign* plugin, use:
+`coredns-keygen example.org` or `dnssec-keygen -a ECDSAP256SHA256 -f KSK example.org`.
+
+## Examples
+
+Sign the `example.org` zone contained in the file `db.example.org` and write the result to
+`./db.example.org.signed` to let the *file* plugin pick it up and serve it. The keys used
+are read from `/etc/coredns/keys/Kexample.org.key` and `/etc/coredns/keys/Kexample.org.private`.
+
+~~~ txt
+example.org {
+    file db.example.org.signed
+
+    sign db.example.org {
+        key file /etc/coredns/keys/Kexample.org
+        directory .
+    }
+}
+~~~
+
+Running this leads to the following log output (note the timers in this example have been set to
+shorter intervals).
+
+~~~ txt
+[WARNING] plugin/file: Failed to open "open /tmp/db.example.org.signed: no such file or directory": trying again in 1m0s
+[INFO] plugin/sign: Signing "example.org." because open /tmp/db.example.org.signed: no such file or directory
+[INFO] plugin/sign: Successfully signed zone "example.org." in "/tmp/db.example.org.signed" with key tags "59725" and 1564766865 SOA serial, elapsed 9.357933ms, next: 2019-08-02T22:27:45.270Z
+[INFO] plugin/file: Successfully reloaded zone "example.org." in "/tmp/db.example.org.signed" with serial 1564766865
+~~~
+
+Or use a single zone file for *multiple* zones; note that the **ZONES** are repeated for both plugins.
+Also note this outputs *multiple* signed output files. Here we use the default output directory
+`/var/lib/coredns`.
+
+~~~ txt
+.
{ + file /var/lib/coredns/db.example.org.signed example.org + file /var/lib/coredns/db.example.net.signed example.net + sign db.example.org example.org example.net { + key directory /etc/coredns/keys + } +} +~~~ + +This is the same configuration, but the zones are put in the server block, but note that you still +need to specify what file is served for what zone in the *file* plugin: + +~~~ txt +example.org example.net { + file var/lib/coredns/db.example.org.signed example.org + file var/lib/coredns/db.example.net.signed example.net + sign db.example.org { + key directory /etc/coredns/keys + } +} +~~~ + +Be careful to fully list the origins you want to sign, if you don't: + +~~~ txt +example.org example.net { + sign plugin/sign/testdata/db.example.org miek.org { + key file /etc/coredns/keys/Kexample.org + } +} +~~~ + +This will lead to `db.example.org` be signed *twice*, as this entire section is parsed twice because +you have specified the origins `example.org` and `example.net` in the server block. + +Forcibly resigning a zone can be accomplished by removing the signed zone file (CoreDNS will keep +on serving it from memory), and sending SIGUSR1 to the process to make it reload and resign the zone +file. + +## See Also + +The DNSSEC RFCs: RFC 4033, RFC 4034 and RFC 4035. And the BCP on DNSSEC, RFC 6781. Further more the +manual pages coredns-keygen(1) and dnssec-keygen(8). And the *file* plugin's documentation. + +Coredns-keygen can be found at +[https://github.com/coredns/coredns-utils](https://github.com/coredns/coredns-utils) in the +coredns-keygen directory. + +Other useful DNSSEC tools can be found in [ldns](https://nlnetlabs.nl/projects/ldns/about/), e.g. +`ldns-key2ds` to create DS records from DNSKEYs. + +## Bugs + +`keys directory` is not implemented. diff --git a/plugin/sign/dnssec.go b/plugin/sign/dnssec.go new file mode 100644 index 0000000..a95e086 --- /dev/null +++ b/plugin/sign/dnssec.go @@ -0,0 +1,20 @@ +package sign + +import ( + "github.com/miekg/dns" +) + +func (p Pair) signRRs(rrs []dns.RR, signerName string, ttl, incep, expir uint32) (*dns.RRSIG, error) { + rrsig := &dns.RRSIG{ + Hdr: dns.RR_Header{Rrtype: dns.TypeRRSIG, Ttl: ttl}, + Algorithm: p.Public.Algorithm, + SignerName: signerName, + KeyTag: p.KeyTag, + OrigTtl: ttl, + Inception: incep, + Expiration: expir, + } + + e := rrsig.Sign(p.Private, rrs) + return rrsig, e +} diff --git a/plugin/sign/file.go b/plugin/sign/file.go new file mode 100644 index 0000000..194ab69 --- /dev/null +++ b/plugin/sign/file.go @@ -0,0 +1,92 @@ +package sign + +import ( + "fmt" + "io" + "os" + "path/filepath" + + "github.com/coredns/coredns/plugin/file" + "github.com/coredns/coredns/plugin/file/tree" + + "github.com/miekg/dns" +) + +// write writes out the zone file to a temporary file which is then moved into the correct place. +func (s *Signer) write(z *file.Zone) error { + f, err := os.CreateTemp(s.directory, "signed-") + if err != nil { + return err + } + + if err := write(f, z); err != nil { + f.Close() + return err + } + + f.Close() + return os.Rename(f.Name(), filepath.Join(s.directory, s.signedfile)) +} + +func write(w io.Writer, z *file.Zone) error { + if _, err := io.WriteString(w, z.Apex.SOA.String()); err != nil { + return err + } + w.Write([]byte("\n")) // RR Stringer() method doesn't include newline, which ends the RR in a zone file, write that here. 
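+	// After the SOA, the rest of the apex is written out explicitly (its RRSIGs, the NS set and
+	// their RRSIGs); the tree walk below then emits every remaining record in the zone.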
+ for _, rr := range z.Apex.SIGSOA { + io.WriteString(w, rr.String()) + w.Write([]byte("\n")) + } + for _, rr := range z.Apex.NS { + io.WriteString(w, rr.String()) + w.Write([]byte("\n")) + } + for _, rr := range z.Apex.SIGNS { + io.WriteString(w, rr.String()) + w.Write([]byte("\n")) + } + err := z.Walk(func(e *tree.Elem, _ map[uint16][]dns.RR) error { + for _, r := range e.All() { + io.WriteString(w, r.String()) + w.Write([]byte("\n")) + } + return nil + }) + return err +} + +// Parse parses the zone in filename and returns a new Zone or an error. This +// is similar to the Parse function in the *file* plugin. However when parsing +// the record types DNSKEY, RRSIG, CDNSKEY and CDS are *not* included in the returned +// zone (if encountered). +func Parse(f io.Reader, origin, fileName string) (*file.Zone, error) { + zp := dns.NewZoneParser(f, dns.Fqdn(origin), fileName) + zp.SetIncludeAllowed(true) + z := file.NewZone(origin, fileName) + seenSOA := false + + for rr, ok := zp.Next(); ok; rr, ok = zp.Next() { + if err := zp.Err(); err != nil { + return nil, err + } + + switch rr.(type) { + case *dns.DNSKEY, *dns.RRSIG, *dns.CDNSKEY, *dns.CDS: + continue + case *dns.SOA: + seenSOA = true + if err := z.Insert(rr); err != nil { + return nil, err + } + default: + if err := z.Insert(rr); err != nil { + return nil, err + } + } + } + if !seenSOA { + return nil, fmt.Errorf("file %q has no SOA record", fileName) + } + + return z, nil +} diff --git a/plugin/sign/file_test.go b/plugin/sign/file_test.go new file mode 100644 index 0000000..72d2b02 --- /dev/null +++ b/plugin/sign/file_test.go @@ -0,0 +1,43 @@ +package sign + +import ( + "os" + "testing" + + "github.com/miekg/dns" +) + +func TestFileParse(t *testing.T) { + f, err := os.Open("testdata/db.miek.nl") + if err != nil { + t.Fatal(err) + } + z, err := Parse(f, "miek.nl.", "testdata/db.miek.nl") + if err != nil { + t.Fatal(err) + } + s := &Signer{ + directory: ".", + signedfile: "db.miek.nl.test", + } + + s.write(z) + defer os.Remove("db.miek.nl.test") + + f, err = os.Open("db.miek.nl.test") + if err != nil { + t.Fatal(err) + } + z, err = Parse(f, "miek.nl.", "db.miek.nl.test") + if err != nil { + t.Fatal(err) + } + if x := z.Apex.SOA.Header().Name; x != "miek.nl." { + t.Errorf("Expected SOA name to be %s, got %s", x, "miek.nl.") + } + apex, _ := z.Search("miek.nl.") + key := apex.Type(dns.TypeDNSKEY) + if key != nil { + t.Errorf("Expected no DNSKEYs, but got %d", len(key)) + } +} diff --git a/plugin/sign/keys.go b/plugin/sign/keys.go new file mode 100644 index 0000000..b999584 --- /dev/null +++ b/plugin/sign/keys.go @@ -0,0 +1,119 @@ +package sign + +import ( + "crypto" + "crypto/ecdsa" + "crypto/rsa" + "fmt" + "io" + "os" + "path/filepath" + "strconv" + "strings" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + + "github.com/miekg/dns" + "golang.org/x/crypto/ed25519" +) + +// Pair holds DNSSEC key information, both the public and private components are stored here. +type Pair struct { + Public *dns.DNSKEY + KeyTag uint16 + Private crypto.Signer +} + +// keyParse reads the public and private key from disk. 
+func keyParse(c *caddy.Controller) ([]Pair, error) { + if !c.NextArg() { + return nil, c.ArgErr() + } + pairs := []Pair{} + config := dnsserver.GetConfig(c) + + switch c.Val() { + case "file": + ks := c.RemainingArgs() + if len(ks) == 0 { + return nil, c.ArgErr() + } + for _, k := range ks { + base := k + // Kmiek.nl.+013+26205.key, handle .private or without extension: Kmiek.nl.+013+26205 + if strings.HasSuffix(k, ".key") { + base = k[:len(k)-4] + } + if strings.HasSuffix(k, ".private") { + base = k[:len(k)-8] + } + if !filepath.IsAbs(base) && config.Root != "" { + base = filepath.Join(config.Root, base) + } + + pair, err := readKeyPair(base+".key", base+".private") + if err != nil { + return nil, err + } + pairs = append(pairs, pair) + } + case "directory": + return nil, fmt.Errorf("directory: not implemented") + } + + return pairs, nil +} + +func readKeyPair(public, private string) (Pair, error) { + rk, err := os.Open(filepath.Clean(public)) + if err != nil { + return Pair{}, err + } + b, err := io.ReadAll(rk) + if err != nil { + return Pair{}, err + } + dnskey, err := dns.NewRR(string(b)) + if err != nil { + return Pair{}, err + } + if _, ok := dnskey.(*dns.DNSKEY); !ok { + return Pair{}, fmt.Errorf("RR in %q is not a DNSKEY: %d", public, dnskey.Header().Rrtype) + } + ksk := dnskey.(*dns.DNSKEY).Flags&(1<<8) == (1<<8) && dnskey.(*dns.DNSKEY).Flags&1 == 1 + if !ksk { + return Pair{}, fmt.Errorf("DNSKEY in %q is not a CSK/KSK", public) + } + + rp, err := os.Open(filepath.Clean(private)) + if err != nil { + return Pair{}, err + } + privkey, err := dnskey.(*dns.DNSKEY).ReadPrivateKey(rp, private) + if err != nil { + return Pair{}, err + } + switch signer := privkey.(type) { + case *ecdsa.PrivateKey: + return Pair{Public: dnskey.(*dns.DNSKEY), KeyTag: dnskey.(*dns.DNSKEY).KeyTag(), Private: signer}, nil + case ed25519.PrivateKey: + return Pair{Public: dnskey.(*dns.DNSKEY), KeyTag: dnskey.(*dns.DNSKEY).KeyTag(), Private: signer}, nil + case *rsa.PrivateKey: + return Pair{Public: dnskey.(*dns.DNSKEY), KeyTag: dnskey.(*dns.DNSKEY).KeyTag(), Private: signer}, nil + default: + return Pair{}, fmt.Errorf("unsupported algorithm %s", signer) + } +} + +// keyTag returns the key tags of the keys in ps as a formatted string. +func keyTag(ps []Pair) string { + if len(ps) == 0 { + return "" + } + s := "" + for _, p := range ps { + s += strconv.Itoa(int(p.KeyTag)) + "," + } + return s[:len(s)-1] +} diff --git a/plugin/sign/log_test.go b/plugin/sign/log_test.go new file mode 100644 index 0000000..2726cd1 --- /dev/null +++ b/plugin/sign/log_test.go @@ -0,0 +1,5 @@ +package sign + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/sign/nsec.go b/plugin/sign/nsec.go new file mode 100644 index 0000000..d7c6a30 --- /dev/null +++ b/plugin/sign/nsec.go @@ -0,0 +1,36 @@ +package sign + +import ( + "sort" + + "github.com/coredns/coredns/plugin/file" + "github.com/coredns/coredns/plugin/file/tree" + + "github.com/miekg/dns" +) + +// names returns the elements of the zone in nsec order. +func names(origin string, z *file.Zone) []string { + // There will also be apex records other than NS and SOA (who are kept separate), as we + // are adding DNSKEY and CDS/CDNSKEY records in the apex *before* we sign. + n := []string{} + z.AuthWalk(func(e *tree.Elem, _ map[uint16][]dns.RR, auth bool) error { + if !auth { + return nil + } + n = append(n, e.Name()) + return nil + }) + return n +} + +// NSEC returns an NSEC record according to name, next, ttl and bitmap. 
Note that the bitmap is sorted before use. +func NSEC(name, next string, ttl uint32, bitmap []uint16) *dns.NSEC { + sort.Slice(bitmap, func(i, j int) bool { return bitmap[i] < bitmap[j] }) + + return &dns.NSEC{ + Hdr: dns.RR_Header{Name: name, Ttl: ttl, Rrtype: dns.TypeNSEC, Class: dns.ClassINET}, + NextDomain: next, + TypeBitMap: bitmap, + } +} diff --git a/plugin/sign/nsec_test.go b/plugin/sign/nsec_test.go new file mode 100644 index 0000000..f272651 --- /dev/null +++ b/plugin/sign/nsec_test.go @@ -0,0 +1,27 @@ +package sign + +import ( + "os" + "testing" + + "github.com/coredns/coredns/plugin/file" +) + +func TestNames(t *testing.T) { + f, err := os.Open("testdata/db.miek.nl_ns") + if err != nil { + t.Error(err) + } + z, err := file.Parse(f, "db.miek.nl_ns", "miek.nl", 0) + if err != nil { + t.Error(err) + } + + names := names("miek.nl.", z) + expected := []string{"miek.nl.", "child.miek.nl.", "www.miek.nl."} + for i := range names { + if names[i] != expected[i] { + t.Errorf("Expected %s, got %s", expected[i], names[i]) + } + } +} diff --git a/plugin/sign/resign_test.go b/plugin/sign/resign_test.go new file mode 100644 index 0000000..2f67f52 --- /dev/null +++ b/plugin/sign/resign_test.go @@ -0,0 +1,40 @@ +package sign + +import ( + "strings" + "testing" + "time" +) + +func TestResignInception(t *testing.T) { + then := time.Date(2019, 7, 18, 22, 50, 0, 0, time.UTC) + // signed yesterday + zr := strings.NewReader(`miek.nl. 1800 IN RRSIG SOA 13 2 1800 20190808191936 20190717161936 59725 miek.nl. eU6gI1OkSEbyt`) + if x := resign(zr, then); x != nil { + t.Errorf("Expected RRSIG to be valid for %s, got invalid: %s", then.Format(timeFmt), x) + } + // inception starts after this date. + zr = strings.NewReader(`miek.nl. 1800 IN RRSIG SOA 13 2 1800 20190808191936 20190731161936 59725 miek.nl. eU6gI1OkSEbyt`) + if x := resign(zr, then); x == nil { + t.Errorf("Expected RRSIG to be invalid for %s, got valid", then.Format(timeFmt)) + } +} + +func TestResignExpire(t *testing.T) { + then := time.Date(2019, 7, 18, 22, 50, 0, 0, time.UTC) + // expires tomorrow + zr := strings.NewReader(`miek.nl. 1800 IN RRSIG SOA 13 2 1800 20190717191936 20190717161936 59725 miek.nl. eU6gI1OkSEbyt`) + if x := resign(zr, then); x == nil { + t.Errorf("Expected RRSIG to be invalid for %s, got valid", then.Format(timeFmt)) + } + // expire too far away + zr = strings.NewReader(`miek.nl. 1800 IN RRSIG SOA 13 2 1800 20190731191936 20190717161936 59725 miek.nl. eU6gI1OkSEbyt`) + if x := resign(zr, then); x != nil { + t.Errorf("Expected RRSIG to be valid for %s, got invalid: %s", then.Format(timeFmt), x) + } + // expired yesterday + zr = strings.NewReader(`miek.nl. 1800 IN RRSIG SOA 13 2 1800 20190721191936 20190717161936 59725 miek.nl. 
eU6gI1OkSEbyt`) + if x := resign(zr, then); x == nil { + t.Errorf("Expected RRSIG to be invalid for %s, got valid", then.Format(timeFmt)) + } +} diff --git a/plugin/sign/setup.go b/plugin/sign/setup.go new file mode 100644 index 0000000..e5f5295 --- /dev/null +++ b/plugin/sign/setup.go @@ -0,0 +1,100 @@ +package sign + +import ( + "fmt" + "math/rand" + "path/filepath" + "time" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" +) + +func init() { plugin.Register("sign", setup) } + +func setup(c *caddy.Controller) error { + sign, err := parse(c) + if err != nil { + return plugin.Error("sign", err) + } + + c.OnStartup(sign.OnStartup) + c.OnStartup(func() error { + for _, signer := range sign.signers { + go signer.refresh(durationRefreshHours) + } + return nil + }) + c.OnShutdown(func() error { + for _, signer := range sign.signers { + close(signer.stop) + } + return nil + }) + + // Don't call AddPlugin, *sign* is not a plugin. + return nil +} + +func parse(c *caddy.Controller) (*Sign, error) { + sign := &Sign{} + config := dnsserver.GetConfig(c) + + for c.Next() { + if !c.NextArg() { + return nil, c.ArgErr() + } + dbfile := c.Val() + if !filepath.IsAbs(dbfile) && config.Root != "" { + dbfile = filepath.Join(config.Root, dbfile) + } + + origins := plugin.OriginsFromArgsOrServerBlock(c.RemainingArgs(), c.ServerBlockKeys) + signers := make([]*Signer, len(origins)) + for i := range origins { + signers[i] = &Signer{ + dbfile: dbfile, + origin: origins[i], + jitterIncep: time.Duration(float32(durationInceptionJitter) * rand.Float32()), + jitterExpir: time.Duration(float32(durationExpirationDayJitter) * rand.Float32()), + directory: "/var/lib/coredns", + stop: make(chan struct{}), + signedfile: fmt.Sprintf("db.%ssigned", origins[i]), // origins[i] is a fqdn, so it ends with a dot, hence %ssigned. + } + } + + for c.NextBlock() { + switch c.Val() { + case "key": + pairs, err := keyParse(c) + if err != nil { + return sign, err + } + for i := range signers { + for _, p := range pairs { + p.Public.Header().Name = signers[i].origin + } + signers[i].keys = append(signers[i].keys, pairs...) + } + case "directory": + dir := c.RemainingArgs() + if len(dir) == 0 || len(dir) > 1 { + return sign, fmt.Errorf("can only be one argument after %q", "directory") + } + if !filepath.IsAbs(dir[0]) && config.Root != "" { + dir[0] = filepath.Join(config.Root, dir[0]) + } + for i := range signers { + signers[i].directory = dir[0] + signers[i].signedfile = fmt.Sprintf("db.%ssigned", signers[i].origin) + } + default: + return nil, c.Errf("unknown property '%s'", c.Val()) + } + } + sign.signers = append(sign.signers, signers...) 
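+		// Every origin listed for this dbfile received its own Signer above, so a multi-origin
+		// stanza produces one db.<origin>signed output file per origin.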
+ } + + return sign, nil +} diff --git a/plugin/sign/setup_test.go b/plugin/sign/setup_test.go new file mode 100644 index 0000000..93d779a --- /dev/null +++ b/plugin/sign/setup_test.go @@ -0,0 +1,75 @@ +package sign + +import ( + "testing" + + "github.com/coredns/caddy" +) + +func TestParse(t *testing.T) { + tests := []struct { + input string + shouldErr bool + exp *Signer + }{ + {`sign testdata/db.miek.nl miek.nl { + key file testdata/Kmiek.nl.+013+59725 + }`, + false, + &Signer{ + keys: []Pair{}, + origin: "miek.nl.", + dbfile: "testdata/db.miek.nl", + directory: "/var/lib/coredns", + signedfile: "db.miek.nl.signed", + }, + }, + {`sign testdata/db.miek.nl example.org { + key file testdata/Kmiek.nl.+013+59725 + directory testdata + }`, + false, + &Signer{ + keys: []Pair{}, + origin: "example.org.", + dbfile: "testdata/db.miek.nl", + directory: "testdata", + signedfile: "db.example.org.signed", + }, + }, + // errors + {`sign db.example.org { + key file /etc/coredns/keys/Kexample.org + }`, + true, + nil, + }, + } + for i, tc := range tests { + c := caddy.NewTestController("dns", tc.input) + sign, err := parse(c) + + if err == nil && tc.shouldErr { + t.Fatalf("Test %d expected errors, but got no error", i) + } + if err != nil && !tc.shouldErr { + t.Fatalf("Test %d expected no errors, but got '%v'", i, err) + } + if tc.shouldErr { + continue + } + signer := sign.signers[0] + if x := signer.origin; x != tc.exp.origin { + t.Errorf("Test %d expected %s as origin, got %s", i, tc.exp.origin, x) + } + if x := signer.dbfile; x != tc.exp.dbfile { + t.Errorf("Test %d expected %s as dbfile, got %s", i, tc.exp.dbfile, x) + } + if x := signer.directory; x != tc.exp.directory { + t.Errorf("Test %d expected %s as directory, got %s", i, tc.exp.directory, x) + } + if x := signer.signedfile; x != tc.exp.signedfile { + t.Errorf("Test %d expected %s as signedfile, got %s", i, tc.exp.signedfile, x) + } + } +} diff --git a/plugin/sign/sign.go b/plugin/sign/sign.go new file mode 100644 index 0000000..982d700 --- /dev/null +++ b/plugin/sign/sign.go @@ -0,0 +1,38 @@ +// Package sign implements a zone signer as a plugin. +package sign + +import ( + "path/filepath" + "time" +) + +// Sign contains signers that sign the zones files. +type Sign struct { + signers []*Signer +} + +// OnStartup scans all signers and signs or resigns zones if needed. +func (s *Sign) OnStartup() error { + for _, signer := range s.signers { + why := signer.resign() + if why == nil { + log.Infof("Skipping signing zone %q in %q: signatures are valid", signer.origin, filepath.Join(signer.directory, signer.signedfile)) + continue + } + go signAndLog(signer, why) + } + return nil +} + +// Various duration constants for signing of the zones. 
+const ( + durationExpireDays = 7 * 24 * time.Hour // max time allowed before expiration + durationResignDays = 6 * 24 * time.Hour // if the last sign happened this long ago, sign again + durationSignatureExpireDays = 32 * 24 * time.Hour // sign for 32 days + durationRefreshHours = 5 * time.Hour // check zones every 5 hours + durationInceptionJitter = -18 * time.Hour // default max jitter for the inception + durationExpirationDayJitter = 5 * 24 * time.Hour // default max jitter for the expiration + durationSignatureInceptionHours = -3 * time.Hour // -(2+1) hours, be sure to catch daylight saving time and such, jitter is subtracted +) + +const timeFmt = "2006-01-02T15:04:05.000Z07:00" diff --git a/plugin/sign/signer.go b/plugin/sign/signer.go new file mode 100644 index 0000000..95ce94b --- /dev/null +++ b/plugin/sign/signer.go @@ -0,0 +1,210 @@ +package sign + +import ( + "fmt" + "io" + "os" + "path/filepath" + "time" + + "github.com/coredns/coredns/plugin/file" + "github.com/coredns/coredns/plugin/file/tree" + clog "github.com/coredns/coredns/plugin/pkg/log" + + "github.com/miekg/dns" +) + +var log = clog.NewWithPlugin("sign") + +// Signer holds the data needed to sign a zone file. +type Signer struct { + keys []Pair + origin string + dbfile string + directory string + jitterIncep time.Duration + jitterExpir time.Duration + + signedfile string + stop chan struct{} +} + +// Sign signs a zone file according to the parameters in s. +func (s *Signer) Sign(now time.Time) (*file.Zone, error) { + rd, err := os.Open(s.dbfile) + if err != nil { + return nil, err + } + + z, err := Parse(rd, s.origin, s.dbfile) + if err != nil { + return nil, err + } + + mttl := z.Apex.SOA.Minttl + ttl := z.Apex.SOA.Header().Ttl + inception, expiration := lifetime(now, s.jitterIncep, s.jitterExpir) + z.Apex.SOA.Serial = uint32(now.Unix()) + + for _, pair := range s.keys { + pair.Public.Header().Ttl = ttl // set TTL on key so it matches the RRSIG. + z.Insert(pair.Public) + z.Insert(pair.Public.ToDS(dns.SHA1).ToCDS()) + z.Insert(pair.Public.ToDS(dns.SHA256).ToCDS()) + z.Insert(pair.Public.ToCDNSKEY()) + } + + names := names(s.origin, z) + ln := len(names) + + for _, pair := range s.keys { + rrsig, err := pair.signRRs([]dns.RR{z.Apex.SOA}, s.origin, ttl, inception, expiration) + if err != nil { + return nil, err + } + z.Insert(rrsig) + // NS apex may not be set if RR's have been discarded because the origin doesn't match. + if len(z.Apex.NS) > 0 { + rrsig, err = pair.signRRs(z.Apex.NS, s.origin, ttl, inception, expiration) + if err != nil { + return nil, err + } + z.Insert(rrsig) + } + } + + // We are walking the tree in the same direction, so names[] can be used here to indicated the next element. + i := 1 + err = z.AuthWalk(func(e *tree.Elem, zrrs map[uint16][]dns.RR, auth bool) error { + if !auth { + return nil + } + + if e.Name() == s.origin { + nsec := NSEC(e.Name(), names[(ln+i)%ln], mttl, append(e.Types(), dns.TypeNS, dns.TypeSOA, dns.TypeRRSIG, dns.TypeNSEC)) + z.Insert(nsec) + } else { + nsec := NSEC(e.Name(), names[(ln+i)%ln], mttl, append(e.Types(), dns.TypeRRSIG, dns.TypeNSEC)) + z.Insert(nsec) + } + + for t, rrs := range zrrs { + // RRSIGs are not signed and NS records are not signed because we are never authoratiative for them. + // The zone's apex nameservers records are not kept in this tree and are signed separately. 
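+			// Glue below a delegation never reaches this loop either: AuthWalk hands it to the
+			// callback with auth=false and it is skipped at the top of this function literal.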
+ if t == dns.TypeRRSIG || t == dns.TypeNS { + continue + } + for _, pair := range s.keys { + rrsig, err := pair.signRRs(rrs, s.origin, rrs[0].Header().Ttl, inception, expiration) + if err != nil { + return err + } + e.Insert(rrsig) + } + } + i++ + return nil + }) + return z, err +} + +// resign checks if the signed zone exists, or needs resigning. +func (s *Signer) resign() error { + signedfile := filepath.Join(s.directory, s.signedfile) + rd, err := os.Open(filepath.Clean(signedfile)) + if err != nil && os.IsNotExist(err) { + return err + } + + now := time.Now().UTC() + return resign(rd, now) +} + +// resign will scan rd and check the signature on the SOA record. We will resign on the basis +// of 2 conditions: +// * either the inception is more than 6 days ago, or +// * we only have 1 week left on the signature +// +// All SOA signatures will be checked. If the SOA isn't found in the first 100 +// records, we will resign the zone. +func resign(rd io.Reader, now time.Time) (why error) { + zp := dns.NewZoneParser(rd, ".", "resign") + zp.SetIncludeAllowed(true) + i := 0 + + for rr, ok := zp.Next(); ok; rr, ok = zp.Next() { + if err := zp.Err(); err != nil { + return err + } + + switch x := rr.(type) { + case *dns.RRSIG: + if x.TypeCovered != dns.TypeSOA { + continue + } + incep, _ := time.Parse("20060102150405", dns.TimeToString(x.Inception)) + // If too long ago, resign. + if now.Sub(incep) >= 0 && now.Sub(incep) > durationResignDays { + return fmt.Errorf("inception %q was more than: %s ago from %s: %s", incep.Format(timeFmt), durationResignDays, now.Format(timeFmt), now.Sub(incep)) + } + // Inception hasn't even start yet. + if now.Sub(incep) < 0 { + return fmt.Errorf("inception %q date is in the future: %s", incep.Format(timeFmt), now.Sub(incep)) + } + + expire, _ := time.Parse("20060102150405", dns.TimeToString(x.Expiration)) + if expire.Sub(now) < durationExpireDays { + return fmt.Errorf("expiration %q is less than: %s away from %s: %s", expire.Format(timeFmt), durationExpireDays, now.Format(timeFmt), expire.Sub(now)) + } + } + i++ + if i > 100 { + // 100 is a random number. A SOA record should be the first in the zonefile, but RFC 1035 doesn't actually mandate this. So it could + // be 3rd or even later. The number 100 looks crazy high enough that it will catch all weird zones, but not high enough to keep the CPU + // busy with parsing all the time. + return fmt.Errorf("no SOA RRSIG found in first 100 records") + } + } + + return nil +} + +func signAndLog(s *Signer, why error) { + now := time.Now().UTC() + z, err := s.Sign(now) + log.Infof("Signing %q because %s", s.origin, why) + if err != nil { + log.Warningf("Error signing %q with key tags %q in %s: %s, next: %s", s.origin, keyTag(s.keys), time.Since(now), err, now.Add(durationRefreshHours).Format(timeFmt)) + return + } + + if err := s.write(z); err != nil { + log.Warningf("Error signing %q: failed to move zone file into place: %s", s.origin, err) + return + } + log.Infof("Successfully signed zone %q in %q with key tags %q and %d SOA serial, elapsed %f, next: %s", s.origin, filepath.Join(s.directory, s.signedfile), keyTag(s.keys), z.Apex.SOA.Serial, time.Since(now).Seconds(), now.Add(durationRefreshHours).Format(timeFmt)) +} + +// refresh checks every val if some zones need to be resigned. 
+func (s *Signer) refresh(val time.Duration) { + tick := time.NewTicker(val) + defer tick.Stop() + for { + select { + case <-s.stop: + return + case <-tick.C: + why := s.resign() + if why == nil { + continue + } + signAndLog(s, why) + } + } +} + +func lifetime(now time.Time, jitterInception, jitterExpiration time.Duration) (uint32, uint32) { + incep := uint32(now.Add(durationSignatureInceptionHours).Add(jitterInception).Unix()) + expir := uint32(now.Add(durationSignatureExpireDays).Add(jitterExpiration).Unix()) + return incep, expir +} diff --git a/plugin/sign/signer_test.go b/plugin/sign/signer_test.go new file mode 100644 index 0000000..17f11ab --- /dev/null +++ b/plugin/sign/signer_test.go @@ -0,0 +1,177 @@ +package sign + +import ( + "os" + "testing" + "time" + + "github.com/coredns/caddy" + + "github.com/miekg/dns" +) + +func TestSign(t *testing.T) { + input := `sign testdata/db.miek.nl miek.nl { + key file testdata/Kmiek.nl.+013+59725 + directory testdata + }` + c := caddy.NewTestController("dns", input) + sign, err := parse(c) + if err != nil { + t.Fatal(err) + } + if len(sign.signers) != 1 { + t.Fatalf("Expected 1 signer, got %d", len(sign.signers)) + } + z, err := sign.signers[0].Sign(time.Now().UTC()) + if err != nil { + t.Error(err) + } + + apex, _ := z.Search("miek.nl.") + if x := apex.Type(dns.TypeDS); len(x) != 0 { + t.Errorf("Expected %d DS records, got %d", 0, len(x)) + } + if x := apex.Type(dns.TypeCDS); len(x) != 2 { + t.Errorf("Expected %d CDS records, got %d", 2, len(x)) + } + if x := apex.Type(dns.TypeCDNSKEY); len(x) != 1 { + t.Errorf("Expected %d CDNSKEY record, got %d", 1, len(x)) + } + if x := apex.Type(dns.TypeDNSKEY); len(x) != 1 { + t.Errorf("Expected %d DNSKEY record, got %d", 1, len(x)) + } +} + +func TestSignApexZone(t *testing.T) { + apex := `$TTL 30M +$ORIGIN example.org. +@ IN SOA linode miek.miek.nl. ( 1282630060 4H 1H 7D 4H ) + IN NS linode +` + if err := os.WriteFile("db.apex-test.example.org", []byte(apex), 0644); err != nil { + t.Fatal(err) + } + defer os.Remove("db.apex-test.example.org") + input := `sign db.apex-test.example.org example.org { + key file testdata/Kmiek.nl.+013+59725 + directory testdata + }` + c := caddy.NewTestController("dns", input) + sign, err := parse(c) + if err != nil { + t.Fatal(err) + } + z, err := sign.signers[0].Sign(time.Now().UTC()) + if err != nil { + t.Error(err) + } + + el, _ := z.Search("example.org.") + nsec := el.Type(dns.TypeNSEC) + if len(nsec) != 1 { + t.Errorf("Expected 1 NSEC for %s, got %d", "example.org.", len(nsec)) + } + if x := nsec[0].(*dns.NSEC).NextDomain; x != "example.org." 
{ + t.Errorf("Expected NSEC NextDomain %s, got %s", "example.org.", x) + } + if x := nsec[0].(*dns.NSEC).TypeBitMap; len(x) != 7 { + t.Errorf("Expected NSEC bitmap to be %d elements, got %d", 7, x) + } + if x := nsec[0].(*dns.NSEC).TypeBitMap; x[6] != dns.TypeCDNSKEY { + t.Errorf("Expected NSEC bitmap element 5 to be %d, got %d", dns.TypeCDNSKEY, x[6]) + } + if x := nsec[0].(*dns.NSEC).TypeBitMap; x[4] != dns.TypeDNSKEY { + t.Errorf("Expected NSEC bitmap element 4 to be %d, got %d", dns.TypeDNSKEY, x[4]) + } + dnskey := el.Type(dns.TypeDNSKEY) + if x := dnskey[0].Header().Ttl; x != 1800 { + t.Errorf("Expected DNSKEY TTL to be %d, got %d", 1800, x) + } + sigs := el.Type(dns.TypeRRSIG) + for _, s := range sigs { + if s.(*dns.RRSIG).TypeCovered == dns.TypeDNSKEY { + if s.(*dns.RRSIG).OrigTtl != dnskey[0].Header().Ttl { + t.Errorf("Expected RRSIG original TTL to match DNSKEY TTL, but %d != %d", s.(*dns.RRSIG).OrigTtl, dnskey[0].Header().Ttl) + } + if s.(*dns.RRSIG).SignerName != dnskey[0].Header().Name { + t.Errorf("Expected RRSIG signer name to match DNSKEY ownername, but %s != %s", s.(*dns.RRSIG).SignerName, dnskey[0].Header().Name) + } + } + } +} + +func TestSignGlue(t *testing.T) { + input := `sign testdata/db.miek.nl miek.nl { + key file testdata/Kmiek.nl.+013+59725 + directory testdata + }` + c := caddy.NewTestController("dns", input) + sign, err := parse(c) + if err != nil { + t.Fatal(err) + } + if len(sign.signers) != 1 { + t.Fatalf("Expected 1 signer, got %d", len(sign.signers)) + } + z, err := sign.signers[0].Sign(time.Now().UTC()) + if err != nil { + t.Error(err) + } + + e, _ := z.Search("ns2.bla.miek.nl.") + sigs := e.Type(dns.TypeRRSIG) + if len(sigs) != 0 { + t.Errorf("Expected no RRSIG for %s, got %d", "ns2.bla.miek.nl.", len(sigs)) + } +} + +func TestSignDS(t *testing.T) { + input := `sign testdata/db.miek.nl_ns miek.nl { + key file testdata/Kmiek.nl.+013+59725 + directory testdata + }` + c := caddy.NewTestController("dns", input) + sign, err := parse(c) + if err != nil { + t.Fatal(err) + } + if len(sign.signers) != 1 { + t.Fatalf("Expected 1 signer, got %d", len(sign.signers)) + } + z, err := sign.signers[0].Sign(time.Now().UTC()) + if err != nil { + t.Error(err) + } + + // dnssec-signzone outputs this for db.miek.nl_ns: + // + // child.miek.nl. 1800 IN NS ns.child.miek.nl. + // child.miek.nl. 1800 IN DS 34385 13 2 fc7397c77afbccb6742fc.... + // child.miek.nl. 1800 IN RRSIG DS 13 3 1800 20191223121229 20191123121229 59725 miek.nl. ZwptLzVVs.... + // child.miek.nl. 14400 IN NSEC www.miek.nl. NS DS RRSIG NSEC + // child.miek.nl. 14400 IN RRSIG NSEC 13 3 14400 20191223121229 20191123121229 59725 miek.nl. w+CcA8... + + name := "child.miek.nl." + e, _ := z.Search(name) + if x := len(e.Types()); x != 4 { // NS DS NSEC and 2x RRSIG + t.Errorf("Expected 4 records for %s, got %d", name, x) + } + + ds := e.Type(dns.TypeDS) + if len(ds) != 1 { + t.Errorf("Expected DS for %s, got %d", name, len(ds)) + } + sigs := e.Type(dns.TypeRRSIG) + if len(sigs) != 2 { + t.Errorf("Expected no RRSIG for %s, got %d", name, len(sigs)) + } + nsec := e.Type(dns.TypeNSEC) + if x := nsec[0].(*dns.NSEC).NextDomain; x != "www.miek.nl." 
{ + t.Errorf("Expected no NSEC NextDomain to be %s for %s, got %s", "www.miek.nl.", name, x) + } + minttl := z.Apex.SOA.Minttl + if x := nsec[0].Header().Ttl; x != minttl { + t.Errorf("Expected no NSEC TTL to be %d for %s, got %d", minttl, "www.miek.nl.", x) + } +} diff --git a/plugin/sign/testdata/Kmiek.nl.+013+59725.key b/plugin/sign/testdata/Kmiek.nl.+013+59725.key new file mode 100644 index 0000000..b3e3654 --- /dev/null +++ b/plugin/sign/testdata/Kmiek.nl.+013+59725.key @@ -0,0 +1,5 @@ +; This is a key-signing key, keyid 59725, for miek.nl. +; Created: 20190709192036 (Tue Jul 9 20:20:36 2019) +; Publish: 20190709192036 (Tue Jul 9 20:20:36 2019) +; Activate: 20190709192036 (Tue Jul 9 20:20:36 2019) +miek.nl. IN DNSKEY 257 3 13 sfzRg5nDVxbeUc51su4MzjgwpOpUwnuu81SlRHqJuXe3SOYOeypR69tZ 52XLmE56TAmPHsiB8Rgk+NTpf0o1Cw== diff --git a/plugin/sign/testdata/Kmiek.nl.+013+59725.private b/plugin/sign/testdata/Kmiek.nl.+013+59725.private new file mode 100644 index 0000000..2545ed9 --- /dev/null +++ b/plugin/sign/testdata/Kmiek.nl.+013+59725.private @@ -0,0 +1,6 @@ +Private-key-format: v1.3 +Algorithm: 13 (ECDSAP256SHA256) +PrivateKey: rm7EdHRca//6xKpJzeoLt/mrfgQnltJ0WpQGtOG59yo= +Created: 20190709192036 +Publish: 20190709192036 +Activate: 20190709192036 diff --git a/plugin/sign/testdata/db.miek.nl b/plugin/sign/testdata/db.miek.nl new file mode 100644 index 0000000..4041b1b --- /dev/null +++ b/plugin/sign/testdata/db.miek.nl @@ -0,0 +1,17 @@ +$TTL 30M +$ORIGIN miek.nl. +@ IN SOA linode.atoom.net. miek.miek.nl. ( 1282630060 4H 1H 7D 4H ) + IN NS linode.atoom.net. + IN MX 1 aspmx.l.google.com. + IN AAAA 2a01:7e00::f03c:91ff:fe79:234c + IN DNSKEY 257 3 13 sfzRg5nDVxbeUc51su4MzjgwpOpUwnuu81SlRHqJuXe3SOYOeypR69tZ52XLmE56TAmPHsiB8Rgk+NTpf0o1Cw== + +a IN AAAA 2a01:7e00::f03c:91ff:fe79:234c +www IN CNAME a + + +bla IN NS ns1.bla.com. +ns3.blaaat.miek.nl. IN AAAA ::1 ; non-glue, should be signed. +; in baliwick nameserver that requires glue, should not be signed +bla IN NS ns2.bla.miek.nl. +ns2.bla.miek.nl. IN A 127.0.0.1 diff --git a/plugin/sign/testdata/db.miek.nl_ns b/plugin/sign/testdata/db.miek.nl_ns new file mode 100644 index 0000000..bd2371f --- /dev/null +++ b/plugin/sign/testdata/db.miek.nl_ns @@ -0,0 +1,10 @@ +$TTL 30M +$ORIGIN miek.nl. +@ IN SOA linode.atoom.net. miek.miek.nl. ( 1282630060 4H 1H 7D 4H ) + NS linode.atoom.net. + DNSKEY 257 3 13 sfzRg5nDVxbeUc51su4MzjgwpOpUwnuu81SlRHqJuXe3SOYOeypR69tZ52XLmE56TAmPHsiB8Rgk+NTpf0o1Cw== + +www AAAA ::1 +child NS ns.child +ns.child AAAA ::1 +child DS 34385 13 2 fc7397c77afbccb6742fcff19c7b1410d0044661e7085fc200ae1ab3d15a5842 diff --git a/plugin/template/README.md b/plugin/template/README.md new file mode 100644 index 0000000..1bca906 --- /dev/null +++ b/plugin/template/README.md @@ -0,0 +1,298 @@ +# template + +## Name + +*template* - allows for dynamic responses based on the incoming query. + +## Description + +The *template* plugin allows you to dynamically respond to queries by just writing a (Go) template. + +## Syntax + +~~~ +template CLASS TYPE [ZONE...] { + match REGEX... + answer RR + additional RR + authority RR + rcode CODE + ederror EXTENDED_ERROR_CODE [EXTRA_REASON] + fallthrough [FALLTHROUGH-ZONE...] +} +~~~ + +* **CLASS** the query class (usually IN or ANY). +* **TYPE** the query type (A, PTR, ... can be ANY to match all types). +* **ZONE** the zone scope(s) for this template. Defaults to the server zones. +* `match` **REGEX** [Go regexp](https://golang.org/pkg/regexp/) that are matched against the incoming question name. 
+ Specifying no regex matches everything (default: `.*`). First matching regex wins. +* `answer|additional|authority` **RR** A [RFC 1035](https://tools.ietf.org/html/rfc1035#section-5) style resource record fragment + built by a [Go template](https://golang.org/pkg/text/template/) that contains the reply. Specifying no answer will result + in a response with an empty answer section. +* `rcode` **CODE** A response code (`NXDOMAIN, SERVFAIL, ...`). The default is `NOERROR`. Valid response code values are + per the `RcodeToString` map defined by the `miekg/dns` package in `msg.go`. +* `ederror` **EXTENDED_ERROR_CODE** is an extended DNS error code as a number defined in `RFC8914` (0, 1, 2,..., 24). + **EXTRA_REASON** is an additional string explaining the reason for returning the error. +* `fallthrough` Continue with the next _template_ instance if the _template_'s **ZONE** matches a query name but no regex match. + If there is no next _template_, continue resolution with the next plugin. If **[FALLTHROUGH-ZONE...]** are listed (for example + `in-addr.arpa` and `ip6.arpa`), then only queries for those zones will be subject to fallthrough. Without + `fallthrough`, when the _template_'s **ZONE** matches a query but no regex match then a `SERVFAIL` response is returned. + +[Also see](#also-see) contains an additional reading list. + +## Templates + +Each resource record is a full-featured [Go template](https://golang.org/pkg/text/template/) with the following predefined data + +* `.Zone` the matched zone string (e.g. `example.`). +* `.Name` the query name, as a string (lowercased). +* `.Class` the query class (usually `IN`). +* `.Type` the RR type requested (e.g. `PTR`). +* `.Match` an array of all matches. `index .Match 0` refers to the whole match. +* `.Group` a map of the named capture groups. +* `.Message` the complete incoming DNS message. +* `.Question` the matched question section. +* `.Remote` client’s IP address +* `.Meta` a function that takes a metadata name and returns the value, if the + metadata plugin is enabled. For example, `.Meta "kubernetes/client-namespace"` + +and the following predefined [template functions](https://golang.org/pkg/text/template#hdr-Functions) + +* `parseInt` interprets a string in the given base and bit size. Equivalent to [strconv.ParseUint](https://golang.org/pkg/strconv#ParseUint). + +The output of the template must be a [RFC 1035](https://tools.ietf.org/html/rfc1035) style resource record (commonly referred to as a "zone file"). + +**WARNING** there is a syntactical problem with Go templates and CoreDNS config files. Expressions + like `{{$var}}` will be interpreted as a reference to an environment variable by CoreDNS (and + Caddy) while `{{ $var }}` will work. See [Bugs](#bugs) and corefile(5). + +## Metrics + +If monitoring is enabled (via the *prometheus* plugin) then the following metrics are exported: + +* `coredns_template_matches_total{server, zone, view, class, type}` the total number of matched requests by regex. +* `coredns_template_template_failures_total{server, zone, view, class, type, section, template}` the number of times the Go templating failed. Regex, section and template label values can be used to map the error back to the config file. +* `coredns_template_rr_failures_total{server, zone, view, class, type, section, template}` the number of times the templated resource record was invalid and could not be parsed. Regex, section and template label values can be used to map the error back to the config file. 
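+
+As noted, these counters appear only when the *prometheus* plugin is enabled in the same
+server block. A minimal sketch of such a configuration (the listen address and the example
+rule are placeholders, not anything this plugin requires):
+
+~~~ corefile
+. {
+    prometheus localhost:9153
+
+    template IN A example.net {
+        answer "{{ .Name }} 60 IN A 192.0.2.1"
+    }
+}
+~~~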
+ +Both failure cases indicate a problem with the template configuration. The `server` label indicates +the server incrementing the metric, see the *metrics* plugin for details. + +## Examples + +### Resolve everything to NXDOMAIN + +The most simplistic template is + +~~~ corefile +. { + template ANY ANY { + rcode NXDOMAIN + } +} +~~~ + +1. This template uses the default zone (`.` or all queries) +2. All queries will be answered (no `fallthrough`) +3. The answer is always NXDOMAIN + +### Resolve .invalid as NXDOMAIN + +The `.invalid` domain is a reserved TLD (see [RFC 2606 Reserved Top Level DNS Names](https://tools.ietf.org/html/rfc2606#section-2)) to indicate invalid domains. + +~~~ corefile +. { + forward . 8.8.8.8 + + template ANY ANY invalid { + rcode NXDOMAIN + authority "invalid. 60 {{ .Class }} SOA ns.invalid. hostmaster.invalid. (1 60 60 60 60)" + ederror 21 "Blocked according to RFC2606" + } +} +~~~ + +1. A query to .invalid will result in NXDOMAIN (rcode) +2. A dummy SOA record is sent to hand out a TTL of 60s for caching purposes +3. Querying `.invalid` in the `CH` class will also cause a NXDOMAIN/SOA response +4. The default regex is `.*` + +### Block invalid search domain completions + +Imagine you run `example.com` with a datacenter `dc1.example.com`. The datacenter domain +is part of the DNS search domain. +However `something.example.com.dc1.example.com` would indicate a fully qualified +domain name (`something.example.com`) that inadvertently has the default domain or search +path (`dc1.example.com`) added. + +~~~ corefile +. { + forward . 8.8.8.8 + + template IN ANY example.com.dc1.example.com { + rcode NXDOMAIN + authority "{{ .Zone }} 60 IN SOA ns.example.com hostmaster.example.com (1 60 60 60 60)" + } +} +~~~ + +A more verbose regex based equivalent would be + +~~~ corefile +. { + forward . 8.8.8.8 + + template IN ANY example.com { + match "example\.com\.(dc1\.example\.com\.)$" + rcode NXDOMAIN + authority "{{ index .Match 1 }} 60 IN SOA ns.{{ index .Match 1 }} hostmaster.{{ index .Match 1 }} (1 60 60 60 60)" + fallthrough + } +} +~~~ + +The regex-based version can do more complex matching/templating while zone-based templating is easier to read and use. + +### Resolve A/PTR for .example + +~~~ corefile +. { + forward . 8.8.8.8 + + # ip-a-b-c-d.example A a.b.c.d + + template IN A example { + match (^|[.])ip-(?P<a>[0-9]*)-(?P<b>[0-9]*)-(?P<c>[0-9]*)-(?P<d>[0-9]*)[.]example[.]$ + answer "{{ .Name }} 60 IN A {{ .Group.a }}.{{ .Group.b }}.{{ .Group.c }}.{{ .Group.d }}" + fallthrough + } + + # d.c.b.a.in-addr.arpa PTR ip-a-b-c-d.example + + template IN PTR in-addr.arpa { + match ^(?P<d>[0-9]*)[.](?P<c>[0-9]*)[.](?P<b>[0-9]*)[.](?P<a>[0-9]*)[.]in-addr[.]arpa[.]$ + answer "{{ .Name }} 60 IN PTR ip-{{ .Group.a }}-{{ .Group.b }}-{{ .Group.c }}-{{ .Group.d }}.example." + } +} +~~~ + +An IPv4 address consists of 4 bytes, `a.b.c.d`. Named groups make it less error-prone to reverse the +IP address in the PTR case. Try to use named groups to explain what your regex and template are doing. + +Note that the A record is actually a wildcard: any subdomain of the IP address will resolve to the IP address. + +Having templates to map certain PTR/A pairs is a common pattern. + +Fallthrough is needed for mixed domains where only some responses are templated. + +### Resolve hexadecimal ip pattern using parseInt + +~~~ corefile +. { + forward . 
8.8.8.8 + + template IN A example { + match "^ip0a(?P<b>[a-f0-9]{2})(?P<c>[a-f0-9]{2})(?P<d>[a-f0-9]{2})[.]example[.]$" + answer "{{ .Name }} 60 IN A 10.{{ parseInt .Group.b 16 8 }}.{{ parseInt .Group.c 16 8 }}.{{ parseInt .Group.d 16 8 }}" + fallthrough + } +} +~~~ + +An IPv4 address can be expressed in a more compact form using its hexadecimal encoding. +For example `ip-10-123-123.example.` can instead be expressed as `ip0a7b7b7b.example.` + +### Resolve multiple ip patterns + +~~~ corefile +. { + forward . 8.8.8.8 + + template IN A example { + match "^ip-(?P<a>10)-(?P<b>[0-9]*)-(?P<c>[0-9]*)-(?P<d>[0-9]*)[.]dc[.]example[.]$" + match "^(?P<a>[0-9]*)[.](?P<b>[0-9]*)[.](?P<c>[0-9]*)[.](?P<d>[0-9]*)[.]ext[.]example[.]$" + answer "{{ .Name }} 60 IN A {{ .Group.a}}.{{ .Group.b }}.{{ .Group.c }}.{{ .Group.d }}" + fallthrough + } +} +~~~ + +Named capture groups can be used to template one response for multiple patterns. + +### Resolve A and MX records for IP templates in .example + +~~~ corefile +. { + forward . 8.8.8.8 + + template IN A example { + match ^ip-10-(?P<b>[0-9]*)-(?P<c>[0-9]*)-(?P<d>[0-9]*)[.]example[.]$ + answer "{{ .Name }} 60 IN A 10.{{ .Group.b }}.{{ .Group.c }}.{{ .Group.d }}" + fallthrough + } + template IN MX example { + match ^ip-10-(?P<b>[0-9]*)-(?P<c>[0-9]*)-(?P<d>[0-9]*)[.]example[.]$ + answer "{{ .Name }} 60 IN MX 10 {{ .Name }}" + additional "{{ .Name }} 60 IN A 10.{{ .Group.b }}.{{ .Group.c }}.{{ .Group.d }}" + fallthrough + } +} +~~~ + +### Adding authoritative nameservers to the response + +~~~ corefile +. { + forward . 8.8.8.8 + + template IN A example { + match ^ip-10-(?P<b>[0-9]*)-(?P<c>[0-9]*)-(?P<d>[0-9]*)[.]example[.]$ + answer "{{ .Name }} 60 IN A 10.{{ .Group.b }}.{{ .Group.c }}.{{ .Group.d }}" + authority "example. 60 IN NS ns0.example." + authority "example. 60 IN NS ns1.example." + additional "ns0.example. 60 IN A 203.0.113.8" + additional "ns1.example. 60 IN A 198.51.100.8" + fallthrough + } + template IN MX example { + match ^ip-10-(?P<b>[0-9]*)-(?P<c>[0-9]*)-(?P<d>[0-9]*)[.]example[.]$ + answer "{{ .Name }} 60 IN MX 10 {{ .Name }}" + additional "{{ .Name }} 60 IN A 10.{{ .Group.b }}.{{ .Group.c }}.{{ .Group.d }}" + authority "example. 60 IN NS ns0.example." + authority "example. 60 IN NS ns1.example." + additional "ns0.example. 60 IN A 203.0.113.8" + additional "ns1.example. 60 IN A 198.51.100.8" + fallthrough + } +} +~~~ + +### Fabricate a CNAME + +This example responds with a CNAME to `google.com` for any DNS query made exactly for `foogle.com`. +The answer will also contain a record for `google.com` if the upstream nameserver can return a record for it of the +requested type. + +~~~ corefile +. { + template IN ANY foogle.com { + match "^foogle\.com\.$" + answer "foogle.com 60 IN CNAME google.com" + } + forward . 8.8.8.8 +} +~~~ + +## Also see + +* [Go regexp](https://golang.org/pkg/regexp/) for details about the regex implementation +* [RE2 syntax reference](https://github.com/google/re2/wiki/Syntax) for details about the regex syntax +* [RFC 1034](https://tools.ietf.org/html/rfc1034#section-3.6.1) and [RFC 1035](https://tools.ietf.org/html/rfc1035#section-5) for the resource record format +* [Go template](https://golang.org/pkg/text/template/) for the template language reference + +## Bugs + +CoreDNS supports [caddyfile environment variables](https://caddyserver.com/docs/caddyfile#env) +with notion of `{$ENV_VAR}`. 
This parser feature will break [Go template variables](https://golang.org/pkg/text/template/#hdr-Variables) notations like`{{$variable}}`. +The equivalent notation `{{ $variable }}` will work. +Try to avoid Go template variables in the context of this plugin. diff --git a/plugin/template/cname_test.go b/plugin/template/cname_test.go new file mode 100644 index 0000000..eef949e --- /dev/null +++ b/plugin/template/cname_test.go @@ -0,0 +1,96 @@ +package template + +import ( + "context" + "regexp" + "testing" + gotmpl "text/template" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +func TestTruncatedCNAME(t *testing.T) { + up := &Upstub{ + Qclass: dns.ClassINET, + Truncated: true, + Case: test.Case{ + Qname: "cname.test.", + Qtype: dns.TypeA, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.CNAME("cname.test. 600 IN CNAME test.up"), + test.A("test.up. 600 IN A 1.2.3.4"), + }, + }, + } + + handler := Handler{ + Zones: []string{"."}, + Templates: []template{{ + regex: []*regexp.Regexp{regexp.MustCompile(`^cname\.test\.$`)}, + answer: []*gotmpl.Template{gotmpl.Must(gotmpl.New("answer").Parse(up.Answer[0].String()))}, + qclass: dns.ClassINET, + qtype: dns.TypeA, + zones: []string{"test."}, + upstream: up, + }}, + } + + r := &dns.Msg{Question: []dns.Question{{Name: up.Qname, Qclass: up.Qclass, Qtype: up.Qtype}}} + w := dnstest.NewRecorder(&test.ResponseWriter{}) + + _, err := handler.ServeDNS(context.TODO(), w, r) + + if err != nil { + t.Fatalf("Unexpected error %q", err) + } + if w.Msg == nil { + t.Fatalf("Unexpected empty response.") + } + if !w.Msg.Truncated { + t.Error("Expected reply to be marked truncated.") + } + err = test.SortAndCheck(w.Msg, up.Case) + if err != nil { + t.Error(err) + } +} + +// Upstub implements an Upstreamer that returns a set response for test purposes +type Upstub struct { + test.Case + Truncated bool + Qclass uint16 +} + +// Lookup returns a set response +func (t *Upstub) Lookup(ctx context.Context, state request.Request, name string, typ uint16) (*dns.Msg, error) { + var answer []dns.RR + // if query type is not CNAME, remove any CNAME with same name as qname from the answer + if t.Qtype != dns.TypeCNAME { + for _, a := range t.Answer { + if c, ok := a.(*dns.CNAME); ok && c.Header().Name == t.Qname { + continue + } + answer = append(answer, a) + } + } else { + answer = t.Answer + } + + return &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Response: true, + Truncated: t.Truncated, + Rcode: t.Rcode, + }, + Question: []dns.Question{{Name: t.Qname, Qtype: t.Qtype, Qclass: t.Qclass}}, + Answer: answer, + Extra: t.Extra, + Ns: t.Ns, + }, nil +} diff --git a/plugin/template/log_test.go b/plugin/template/log_test.go new file mode 100644 index 0000000..13d6e6b --- /dev/null +++ b/plugin/template/log_test.go @@ -0,0 +1,5 @@ +package template + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/template/metrics.go b/plugin/template/metrics.go new file mode 100644 index 0000000..6a6912a --- /dev/null +++ b/plugin/template/metrics.go @@ -0,0 +1,32 @@ +package template + +import ( + "github.com/coredns/coredns/plugin" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +var ( + // templateMatchesCount is the counter of template regex matches. 
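+	// It is labeled by server, zone, view, class and type.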
+ templateMatchesCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "template", + Name: "matches_total", + Help: "Counter of template regex matches.", + }, []string{"server", "zone", "view", "class", "type"}) + // templateFailureCount is the counter of go template failures. + templateFailureCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "template", + Name: "template_failures_total", + Help: "Counter of go template failures.", + }, []string{"server", "zone", "view", "class", "type", "section", "template"}) + // templateRRFailureCount is the counter of mis-templated RRs. + templateRRFailureCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "template", + Name: "rr_failures_total", + Help: "Counter of mis-templated RRs.", + }, []string{"server", "zone", "view", "class", "type", "section", "template"}) +) diff --git a/plugin/template/setup.go b/plugin/template/setup.go new file mode 100644 index 0000000..56058f0 --- /dev/null +++ b/plugin/template/setup.go @@ -0,0 +1,162 @@ +package template + +import ( + "regexp" + "strconv" + gotmpl "text/template" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/upstream" + + "github.com/miekg/dns" +) + +func init() { plugin.Register("template", setupTemplate) } + +func setupTemplate(c *caddy.Controller) error { + handler, err := templateParse(c) + if err != nil { + return plugin.Error("template", err) + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + handler.Next = next + return handler + }) + + return nil +} + +func templateParse(c *caddy.Controller) (handler Handler, err error) { + handler.Templates = make([]template, 0) + + for c.Next() { + if !c.NextArg() { + return handler, c.ArgErr() + } + class, ok := dns.StringToClass[c.Val()] + if !ok { + return handler, c.Errf("invalid query class %s", c.Val()) + } + + if !c.NextArg() { + return handler, c.ArgErr() + } + qtype, ok := dns.StringToType[c.Val()] + if !ok { + return handler, c.Errf("invalid RR class %s", c.Val()) + } + + zones := plugin.OriginsFromArgsOrServerBlock(c.RemainingArgs(), c.ServerBlockKeys) + handler.Zones = append(handler.Zones, zones...) 
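+		// Each occurrence of the directive becomes its own template entry: the zones
+		// parsed above scope this entry, while handler.Zones accumulates the union
+		// that ServeDNS uses for its initial zone match.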
+ t := template{qclass: class, qtype: qtype, zones: zones} + + t.regex = make([]*regexp.Regexp, 0) + templatePrefix := "" + + t.answer = make([]*gotmpl.Template, 0) + t.upstream = upstream.New() + + for c.NextBlock() { + switch c.Val() { + case "match": + args := c.RemainingArgs() + if len(args) == 0 { + return handler, c.ArgErr() + } + for _, regex := range args { + r, err := regexp.Compile(regex) + if err != nil { + return handler, c.Errf("could not parse regex: %s, %v", regex, err) + } + templatePrefix = templatePrefix + regex + " " + t.regex = append(t.regex, r) + } + + case "answer": + args := c.RemainingArgs() + if len(args) == 0 { + return handler, c.ArgErr() + } + for _, answer := range args { + tmpl, err := newTemplate("answer", answer) + if err != nil { + return handler, c.Errf("could not compile template: %s, %v", c.Val(), err) + } + t.answer = append(t.answer, tmpl) + } + + case "additional": + args := c.RemainingArgs() + if len(args) == 0 { + return handler, c.ArgErr() + } + for _, additional := range args { + tmpl, err := newTemplate("additional", additional) + if err != nil { + return handler, c.Errf("could not compile template: %s, %v\n", c.Val(), err) + } + t.additional = append(t.additional, tmpl) + } + + case "authority": + args := c.RemainingArgs() + if len(args) == 0 { + return handler, c.ArgErr() + } + for _, authority := range args { + tmpl, err := newTemplate("authority", authority) + if err != nil { + return handler, c.Errf("could not compile template: %s, %v\n", c.Val(), err) + } + t.authority = append(t.authority, tmpl) + } + + case "rcode": + if !c.NextArg() { + return handler, c.ArgErr() + } + rcode, ok := dns.StringToRcode[c.Val()] + if !ok { + return handler, c.Errf("unknown rcode %s", c.Val()) + } + t.rcode = rcode + + case "ederror": + args := c.RemainingArgs() + if len(args) != 1 && len(args) != 2 { + return handler, c.ArgErr() + } + + code, err := strconv.ParseUint(args[0], 10, 16) + if err != nil { + return handler, c.Errf("error parsing extended DNS error code %s, %v\n", c.Val(), err) + } + if len(args) == 2 { + t.ederror = &ederror{code: uint16(code), reason: args[1]} + } else { + t.ederror = &ederror{code: uint16(code)} + } + + case "fallthrough": + t.fall.SetZonesFromArgs(c.RemainingArgs()) + + case "upstream": + // remove soon + c.RemainingArgs() + default: + return handler, c.ArgErr() + } + } + + if len(t.regex) == 0 { + t.regex = append(t.regex, regexp.MustCompile(".*")) + } + + handler.Templates = append(handler.Templates, t) + } + + return +} diff --git a/plugin/template/setup_test.go b/plugin/template/setup_test.go new file mode 100644 index 0000000..345525d --- /dev/null +++ b/plugin/template/setup_test.go @@ -0,0 +1,200 @@ +package template + +import ( + "testing" + + "github.com/coredns/caddy" +) + +func TestSetup(t *testing.T) { + c := caddy.NewTestController("dns", `template ANY ANY { + rcode + }`) + err := setupTemplate(c) + if err == nil { + t.Errorf("Expected setupTemplate to fail on broken template, got no error") + } + c = caddy.NewTestController("dns", `template ANY ANY { + rcode NXDOMAIN + }`) + err = setupTemplate(c) + if err != nil { + t.Errorf("Expected no errors, got: %v", err) + } +} + +func TestSetupParse(t *testing.T) { + serverBlockKeys := []string{"domain.com.:8053", "dynamic.domain.com.:8053"} + + tests := []struct { + inputFileRules string + shouldErr bool + }{ + // parse errors + {`template`, true}, + {`template X`, true}, + {`template ANY`, true}, + {`template ANY X`, true}, + { + `template ANY ANY .* { + notavailable + 
}`, + true, + }, + { + `template ANY ANY { + answer + }`, + true, + }, + { + `template ANY ANY { + additional + }`, + true, + }, + { + `template ANY ANY { + rcode + }`, + true, + }, + { + `template ANY ANY { + rcode UNDEFINED + }`, + true, + }, + { + `template ANY ANY { + answer "{{" + }`, + true, + }, + { + `template ANY ANY { + additional "{{" + }`, + true, + }, + { + `template ANY ANY { + authority "{{" + }`, + true, + }, + { + `template ANY ANY { + answer "{{ notAFunction }}" + }`, + true, + }, + { + `template ANY ANY { + answer "{{ parseInt }}" + additional "{{ parseInt }}" + authority "{{ parseInt }}" + }`, + false, + }, + // examples + {`template ANY ANY (?P<x>`, false}, + { + `template ANY ANY { + + }`, + false, + }, + { + `template ANY A example.com { + match ip-(?P<a>[0-9]*)-(?P<b>[0-9]*)-(?P<c>[0-9]*)-(?P<d>[0-9]*)[.]example[.]com + answer "{{ .Name }} A {{ .Group.a }}.{{ .Group.b }}.{{ .Group.c }}.{{ .Grup.d }}." + fallthrough + }`, + false, + }, + { + `template ANY AAAA example.com { + match ip-(?P<a>[0-9]*)-(?P<b>[0-9]*)-(?P<c>[0-9]*)-(?P<d>[0-9]*)[.]example[.]com + authority "example.com 60 IN SOA ns.example.com hostmaster.example.com (1 60 60 60 60)" + fallthrough + }`, + false, + }, + { + `template IN ANY example.com { + match "[.](example[.]com[.]dc1[.]example[.]com[.])$" + rcode NXDOMAIN + authority "{{ index .Match 1 }} 60 IN SOA ns.{{ index .Match 1 }} hostmaster.example.com (1 60 60 60 60)" + fallthrough example.com + }`, + false, + }, + { + `template IN A example { + match ^ip-10-(?P<b>[0-9]*)-(?P<c>[0-9]*)-(?P<d>[0-9]*)[.]example[.]$ + answer "{{ .Name }} 60 IN A 10.{{ .Group.b }}.{{ .Group.c }}.{{ .Group.d }}" + } + template IN MX example. { + match ^ip-10-(?P<b>[0-9]*)-(?P<c>[0-9]*)-(?P<d>[0-9]*)[.]example[.]$ + answer "{{ .Name }} 60 IN MX 10 {{ .Name }}" + additional "{{ .Name }} 60 IN A 10.{{ .Group.b }}.{{ .Group.c }}.{{ .Group.d }}" + }`, + false, + }, + { + `template IN A example { + match ^ip0a(?P<b>[a-f0-9]{2})(?P<c>[a-f0-9]{2})(?P<d>[a-f0-9]{2})[.]example[.]$ + answer "{{ .Name }} 3600 IN A 10.{{ parseInt .Group.b 16 8 }}.{{ parseInt .Group.c 16 8 }}.{{ parseInt .Group.d 16 8 }}" + }`, + false, + }, + { + `template IN MX example { + match ^ip-10-(?P<b>[0-9]*)-(?P<c>[0-9]*)-(?P<d>[0-9]*)[.]example[.]$ + answer "{{ .Name }} 60 IN MX 10 {{ .Name }}" + additional "{{ .Name }} 60 IN A 10.{{ .Group.b }}.{{ .Group.c }}.{{ .Group.d }}" + authority "example. 60 IN NS ns0.example." + authority "example. 60 IN NS ns1.example." + additional "ns0.example. 60 IN A 203.0.113.8" + additional "ns1.example. 60 IN A 198.51.100.8" + }`, + false, + }, + { + `template ANY ANY invalid { + rcode NXDOMAIN + authority "invalid. 60 {{ .Class }} SOA ns.invalid. hostmaster.invalid. (1 60 60 60 60)" + ederror 21 "Blocked according to RFC2606" + }`, + false, + }, + { + `template ANY ANY invalid { + rcode NXDOMAIN + authority "invalid. 60 {{ .Class }} SOA ns.invalid. hostmaster.invalid. (1 60 60 60 60)" + ederror invalid "Blocked according to RFC2606" + }`, + true, + }, + { + `template ANY ANY invalid { + rcode NXDOMAIN + authority "invalid. 60 {{ .Class }} SOA ns.invalid. hostmaster.invalid. 
(1 60 60 60 60)" + ederror too many arguments + }`, + true, + }, + } + for i, test := range tests { + c := caddy.NewTestController("dns", test.inputFileRules) + c.ServerBlockKeys = serverBlockKeys + templates, err := templateParse(c) + + if err == nil && test.shouldErr { + t.Fatalf("Test %d expected errors, but got no error\n---\n%s\n---\n%v", i, test.inputFileRules, templates) + } else if err != nil && !test.shouldErr { + t.Fatalf("Test %d expected no errors, but got '%v'", i, err) + } + } +} diff --git a/plugin/template/template.go b/plugin/template/template.go new file mode 100644 index 0000000..5eac81e --- /dev/null +++ b/plugin/template/template.go @@ -0,0 +1,227 @@ +package template + +import ( + "bytes" + "context" + "regexp" + "strconv" + gotmpl "text/template" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/metadata" + "github.com/coredns/coredns/plugin/metrics" + "github.com/coredns/coredns/plugin/pkg/fall" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// Handler is a plugin handler that takes a query and templates a response. +type Handler struct { + Zones []string + + Next plugin.Handler + Templates []template +} + +type template struct { + zones []string + rcode int + regex []*regexp.Regexp + answer []*gotmpl.Template + additional []*gotmpl.Template + authority []*gotmpl.Template + qclass uint16 + qtype uint16 + ederror *ederror + fall fall.F + upstream Upstreamer +} + +type ederror struct { + code uint16 + reason string +} + +// Upstreamer looks up targets of CNAME templates +type Upstreamer interface { + Lookup(ctx context.Context, state request.Request, name string, typ uint16) (*dns.Msg, error) +} + +type templateData struct { + Zone string + Name string + Regex string + Match []string + Group map[string]string + Class string + Type string + Message *dns.Msg + Question *dns.Question + Remote string + md map[string]metadata.Func +} + +func (data *templateData) Meta(metaName string) string { + if data.md == nil { + return "" + } + + if f, ok := data.md[metaName]; ok { + return f() + } + + return "" +} + +// ServeDNS implements the plugin.Handler interface. +func (h Handler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + state := request.Request{W: w, Req: r} + + zone := plugin.Zones(h.Zones).Matches(state.Name()) + if zone == "" { + return plugin.NextOrFailure(h.Name(), h.Next, ctx, w, r) + } + + for _, template := range h.Templates { + data, match, fthrough := template.match(ctx, state) + if !match { + if !fthrough { + return dns.RcodeServerFailure, nil + } + continue + } + + templateMatchesCount.WithLabelValues(metrics.WithServer(ctx), data.Zone, metrics.WithView(ctx), data.Class, data.Type).Inc() + + if template.rcode == dns.RcodeServerFailure { + return template.rcode, nil + } + + msg := new(dns.Msg) + msg.SetReply(r) + msg.Authoritative = true + msg.Rcode = template.rcode + + for _, answer := range template.answer { + rr, err := executeRRTemplate(metrics.WithServer(ctx), metrics.WithView(ctx), "answer", answer, data) + if err != nil { + return dns.RcodeServerFailure, err + } + msg.Answer = append(msg.Answer, rr) + if template.upstream != nil && (state.QType() == dns.TypeA || state.QType() == dns.TypeAAAA) && rr.Header().Rrtype == dns.TypeCNAME { + if up, err := template.upstream.Lookup(ctx, state, rr.(*dns.CNAME).Target, state.QType()); err == nil && up != nil { + msg.Truncated = up.Truncated + msg.Answer = append(msg.Answer, up.Answer...) 
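+					// The upstream lookup above chases the CNAME target for A/AAAA queries so
+					// the client gets the target's records in the same reply; the TC bit is
+					// copied so a truncated upstream answer is still reported as truncated.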
+ } + } + } + for _, additional := range template.additional { + rr, err := executeRRTemplate(metrics.WithServer(ctx), metrics.WithView(ctx), "additional", additional, data) + if err != nil { + return dns.RcodeServerFailure, err + } + msg.Extra = append(msg.Extra, rr) + } + for _, authority := range template.authority { + rr, err := executeRRTemplate(metrics.WithServer(ctx), metrics.WithView(ctx), "authority", authority, data) + if err != nil { + return dns.RcodeServerFailure, err + } + msg.Ns = append(msg.Ns, rr) + } + + if template.ederror != nil { + msg = msg.SetEdns0(4096, true) + ede := dns.EDNS0_EDE{InfoCode: template.ederror.code, ExtraText: template.ederror.reason} + msg.IsEdns0().Option = append(msg.IsEdns0().Option, &ede) + } + + w.WriteMsg(msg) + return template.rcode, nil + } + + return plugin.NextOrFailure(h.Name(), h.Next, ctx, w, r) +} + +// Name implements the plugin.Handler interface. +func (h Handler) Name() string { return "template" } + +func executeRRTemplate(server, view, section string, template *gotmpl.Template, data *templateData) (dns.RR, error) { + buffer := &bytes.Buffer{} + err := template.Execute(buffer, data) + if err != nil { + templateFailureCount.WithLabelValues(server, data.Zone, view, data.Class, data.Type, section, template.Tree.Root.String()).Inc() + return nil, err + } + rr, err := dns.NewRR(buffer.String()) + if err != nil { + templateRRFailureCount.WithLabelValues(server, data.Zone, view, data.Class, data.Type, section, template.Tree.Root.String()).Inc() + return rr, err + } + return rr, nil +} + +func newTemplate(name, text string) (*gotmpl.Template, error) { + funcMap := gotmpl.FuncMap{ + "parseInt": strconv.ParseUint, + } + return gotmpl.New(name).Funcs(funcMap).Parse(text) +} + +func (t template) match(ctx context.Context, state request.Request) (*templateData, bool, bool) { + q := state.Req.Question[0] + data := &templateData{md: metadata.ValueFuncs(ctx), Remote: state.IP()} + + zone := plugin.Zones(t.zones).Matches(state.Name()) + if zone == "" { + return data, false, true + } + + if t.qclass != dns.ClassANY && q.Qclass != dns.ClassANY && q.Qclass != t.qclass { + return data, false, true + } + if t.qtype != dns.TypeANY && q.Qtype != dns.TypeANY && q.Qtype != t.qtype { + return data, false, true + } + + for _, regex := range t.regex { + if !regex.MatchString(state.Name()) { + continue + } + + data.Zone = zone + data.Regex = regex.String() + data.Name = state.Name() + data.Question = &q + data.Message = state.Req + if q.Qclass != dns.ClassANY { + data.Class = dns.ClassToString[q.Qclass] + } else { + data.Class = dns.ClassToString[t.qclass] + } + if q.Qtype != dns.TypeANY { + data.Type = dns.TypeToString[q.Qtype] + } else { + data.Type = dns.TypeToString[t.qtype] + } + + matches := regex.FindStringSubmatch(state.Name()) + data.Match = make([]string, len(matches)) + data.Group = make(map[string]string) + groupNames := regex.SubexpNames() + for i, m := range matches { + data.Match[i] = m + data.Group[strconv.Itoa(i)] = m + } + for i, m := range matches { + if len(groupNames[i]) > 0 { + data.Group[groupNames[i]] = m + } + } + + return data, true, false + } + + return data, false, t.fall.Through(state.Name()) +} diff --git a/plugin/template/template_test.go b/plugin/template/template_test.go new file mode 100644 index 0000000..4c16098 --- /dev/null +++ b/plugin/template/template_test.go @@ -0,0 +1,679 @@ +package template + +import ( + "context" + "fmt" + "regexp" + "testing" + gotmpl "text/template" + + "github.com/coredns/caddy" + 
"github.com/coredns/coredns/plugin/metadata" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/pkg/fall" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestHandler(t *testing.T) { + exampleDomainATemplate := template{ + regex: []*regexp.Regexp{regexp.MustCompile("(^|[.])ip-10-(?P<b>[0-9]*)-(?P<c>[0-9]*)-(?P<d>[0-9]*)[.]example[.]$")}, + answer: []*gotmpl.Template{gotmpl.Must(newTemplate("answer", "{{ .Name }} 60 IN A 10.{{ .Group.b }}.{{ .Group.c }}.{{ .Group.d }}"))}, + qclass: dns.ClassANY, + qtype: dns.TypeANY, + fall: fall.Root, + zones: []string{"."}, + } + exampleDomainAParseIntTemplate := template{ + regex: []*regexp.Regexp{regexp.MustCompile("^ip0a(?P<b>[a-f0-9]{2})(?P<c>[a-f0-9]{2})(?P<d>[a-f0-9]{2})[.]example[.]$")}, + answer: []*gotmpl.Template{gotmpl.Must(newTemplate("answer", "{{ .Name }} 60 IN A 10.{{ parseInt .Group.b 16 8 }}.{{ parseInt .Group.c 16 8 }}.{{ parseInt .Group.d 16 8 }}"))}, + qclass: dns.ClassANY, + qtype: dns.TypeANY, + fall: fall.Root, + zones: []string{"."}, + } + exampleDomainIPATemplate := template{ + regex: []*regexp.Regexp{regexp.MustCompile(".*")}, + answer: []*gotmpl.Template{gotmpl.Must(newTemplate("answer", "{{ .Name }} 60 IN A {{ .Remote }}"))}, + qclass: dns.ClassINET, + qtype: dns.TypeA, + fall: fall.Root, + zones: []string{"."}, + } + exampleDomainANSTemplate := template{ + regex: []*regexp.Regexp{regexp.MustCompile("(^|[.])ip-10-(?P<b>[0-9]*)-(?P<c>[0-9]*)-(?P<d>[0-9]*)[.]example[.]$")}, + answer: []*gotmpl.Template{gotmpl.Must(newTemplate("answer", "{{ .Name }} 60 IN A 10.{{ .Group.b }}.{{ .Group.c }}.{{ .Group.d }}"))}, + additional: []*gotmpl.Template{gotmpl.Must(newTemplate("additional", "ns0.example. IN A 203.0.113.8"))}, + authority: []*gotmpl.Template{gotmpl.Must(newTemplate("authority", "example. IN NS ns0.example.com."))}, + qclass: dns.ClassANY, + qtype: dns.TypeANY, + fall: fall.Root, + zones: []string{"."}, + } + exampleDomainMXTemplate := template{ + regex: []*regexp.Regexp{regexp.MustCompile("(^|[.])ip-10-(?P<b>[0-9]*)-(?P<c>[0-9]*)-(?P<d>[0-9]*)[.]example[.]$")}, + answer: []*gotmpl.Template{gotmpl.Must(newTemplate("answer", "{{ .Name }} 60 MX 10 {{ .Name }}"))}, + additional: []*gotmpl.Template{gotmpl.Must(newTemplate("additional", "{{ .Name }} 60 IN A 10.{{ .Group.b }}.{{ .Group.c }}.{{ .Group.d }}"))}, + qclass: dns.ClassANY, + qtype: dns.TypeANY, + fall: fall.Root, + zones: []string{"."}, + } + invalidDomainTemplate := template{ + regex: []*regexp.Regexp{regexp.MustCompile("[.]invalid[.]$")}, + rcode: dns.RcodeNameError, + answer: []*gotmpl.Template{gotmpl.Must(newTemplate("answer", "invalid. 60 {{ .Class }} SOA a.invalid. b.invalid. 
(1 60 60 60 60)"))}, + qclass: dns.ClassANY, + qtype: dns.TypeANY, + fall: fall.Root, + zones: []string{"."}, + } + rcodeServfailTemplate := template{ + regex: []*regexp.Regexp{regexp.MustCompile(".*")}, + rcode: dns.RcodeServerFailure, + qclass: dns.ClassANY, + qtype: dns.TypeANY, + fall: fall.Root, + zones: []string{"."}, + } + brokenTemplate := template{ + regex: []*regexp.Regexp{regexp.MustCompile("[.]example[.]$")}, + answer: []*gotmpl.Template{gotmpl.Must(newTemplate("answer", "{{ .Name }} 60 IN TXT \"{{ index .Match 2 }}\""))}, + qclass: dns.ClassANY, + qtype: dns.TypeANY, + fall: fall.Root, + zones: []string{"."}, + } + brokenParseIntTemplate := template{ + regex: []*regexp.Regexp{regexp.MustCompile("[.]example[.]$")}, + answer: []*gotmpl.Template{gotmpl.Must(newTemplate("answer", "{{ .Name }} 60 IN TXT \"{{ parseInt \"gg\" 16 8 }}\""))}, + qclass: dns.ClassANY, + qtype: dns.TypeANY, + fall: fall.Root, + zones: []string{"."}, + } + nonRRTemplate := template{ + regex: []*regexp.Regexp{regexp.MustCompile("[.]example[.]$")}, + answer: []*gotmpl.Template{gotmpl.Must(newTemplate("answer", "{{ .Name }}"))}, + qclass: dns.ClassANY, + qtype: dns.TypeANY, + fall: fall.Root, + zones: []string{"."}, + } + nonRRAdditionalTemplate := template{ + regex: []*regexp.Regexp{regexp.MustCompile("[.]example[.]$")}, + additional: []*gotmpl.Template{gotmpl.Must(newTemplate("answer", "{{ .Name }}"))}, + qclass: dns.ClassANY, + qtype: dns.TypeANY, + fall: fall.Root, + zones: []string{"."}, + } + nonRRAuthoritativeTemplate := template{ + regex: []*regexp.Regexp{regexp.MustCompile("[.]example[.]$")}, + authority: []*gotmpl.Template{gotmpl.Must(newTemplate("answer", "{{ .Name }}"))}, + qclass: dns.ClassANY, + qtype: dns.TypeANY, + fall: fall.Root, + zones: []string{"."}, + } + cnameTemplate := template{ + regex: []*regexp.Regexp{regexp.MustCompile("example[.]net[.]")}, + answer: []*gotmpl.Template{gotmpl.Must(newTemplate("answer", "example.net 60 IN CNAME target.example.com"))}, + qclass: dns.ClassANY, + qtype: dns.TypeANY, + fall: fall.Root, + zones: []string{"."}, + } + mdTemplate := template{ + regex: []*regexp.Regexp{regexp.MustCompile("(^|[.])ip-10-(?P<b>[0-9]*)-(?P<c>[0-9]*)-(?P<d>[0-9]*)[.]example[.]$")}, + answer: []*gotmpl.Template{gotmpl.Must(newTemplate("answer", `{{ .Meta "foo" }}-{{ .Name }} 60 IN A 10.{{ .Group.b }}.{{ .Group.c }}.{{ .Group.d }}`))}, + additional: []*gotmpl.Template{gotmpl.Must(newTemplate("additional", `{{ .Meta "bar" }}.example. IN A 203.0.113.8`))}, + authority: []*gotmpl.Template{gotmpl.Must(newTemplate("authority", `example. IN NS {{ .Meta "bar" }}.example.com.`))}, + qclass: dns.ClassANY, + qtype: dns.TypeANY, + fall: fall.Root, + zones: []string{"."}, + } + mdMissingTemplate := template{ + regex: []*regexp.Regexp{regexp.MustCompile("(^|[.])ip-10-(?P<b>[0-9]*)-(?P<c>[0-9]*)-(?P<d>[0-9]*)[.]example[.]$")}, + answer: []*gotmpl.Template{gotmpl.Must(newTemplate("answer", `{{ .Meta "foofoo" }}{{ .Name }} 60 IN A 10.{{ .Group.b }}.{{ .Group.c }}.{{ .Group.d }}`))}, + qclass: dns.ClassANY, + qtype: dns.TypeANY, + fall: fall.Root, + zones: []string{"."}, + } + templateWithEDE := template{ + rcode: dns.RcodeNameError, + regex: []*regexp.Regexp{regexp.MustCompile(".*")}, + authority: []*gotmpl.Template{gotmpl.Must(newTemplate("authority", "invalid. 60 {{ .Class }} SOA ns.invalid. hostmaster.invalid. 
(1 60 60 60 60)"))}, + qclass: dns.ClassANY, + qtype: dns.TypeANY, + fall: fall.Root, + zones: []string{"."}, + ederror: &ederror{code: 21, reason: "Blocked due to RFC2606"}, + } + + tests := []struct { + tmpl template + qname string + name string + qclass uint16 + qtype uint16 + expectedCode int + expectedErr string + verifyResponse func(*dns.Msg) error + md map[string]string + }{ + { + name: "RcodeServFail", + tmpl: rcodeServfailTemplate, + qname: "test.invalid.", + expectedCode: dns.RcodeServerFailure, + verifyResponse: func(r *dns.Msg) error { + return nil + }, + }, + { + name: "ExampleDomainNameMismatch", + tmpl: exampleDomainATemplate, + qclass: dns.ClassINET, + qtype: dns.TypeA, + qname: "test.invalid.", + expectedCode: rcodeFallthrough, + }, + { + name: "BrokenTemplate", + tmpl: brokenTemplate, + qclass: dns.ClassINET, + qtype: dns.TypeANY, + qname: "test.example.", + expectedCode: dns.RcodeServerFailure, + expectedErr: `template: answer:1:26: executing "answer" at <index .Match 2>: error calling index: index out of range: 2`, + verifyResponse: func(r *dns.Msg) error { + return nil + }, + }, + { + name: "NonRRTemplate", + tmpl: nonRRTemplate, + qclass: dns.ClassINET, + qtype: dns.TypeANY, + qname: "test.example.", + expectedCode: dns.RcodeServerFailure, + expectedErr: `dns: not a TTL: "test.example." at line: 1:13`, + verifyResponse: func(r *dns.Msg) error { + return nil + }, + }, + { + name: "NonRRAdditionalTemplate", + tmpl: nonRRAdditionalTemplate, + qclass: dns.ClassINET, + qtype: dns.TypeANY, + qname: "test.example.", + expectedCode: dns.RcodeServerFailure, + expectedErr: `dns: not a TTL: "test.example." at line: 1:13`, + verifyResponse: func(r *dns.Msg) error { + return nil + }, + }, + { + name: "NonRRAuthorityTemplate", + tmpl: nonRRAuthoritativeTemplate, + qclass: dns.ClassINET, + qtype: dns.TypeANY, + qname: "test.example.", + expectedCode: dns.RcodeServerFailure, + expectedErr: `dns: not a TTL: "test.example." 
at line: 1:13`, + verifyResponse: func(r *dns.Msg) error { + return nil + }, + }, + { + name: "ExampleIPMatch", + tmpl: exampleDomainIPATemplate, + qclass: dns.ClassINET, + qtype: dns.TypeA, + qname: "test.example.", + verifyResponse: func(r *dns.Msg) error { + if len(r.Answer) != 1 { + return fmt.Errorf("expected 1 answer, got %v", len(r.Answer)) + } + if r.Answer[0].Header().Rrtype != dns.TypeA { + return fmt.Errorf("expected an A record answer, got %v", dns.TypeToString[r.Answer[0].Header().Rrtype]) + } + if r.Answer[0].(*dns.A).A.String() != "10.240.0.1" { + return fmt.Errorf("expected an A record for 10.95.12.8, got %v", r.Answer[0].String()) + } + return nil + }, + }, + { + name: "ExampleDomainMatch", + tmpl: exampleDomainATemplate, + qclass: dns.ClassINET, + qtype: dns.TypeA, + qname: "ip-10-95-12-8.example.", + verifyResponse: func(r *dns.Msg) error { + if len(r.Answer) != 1 { + return fmt.Errorf("expected 1 answer, got %v", len(r.Answer)) + } + if r.Answer[0].Header().Rrtype != dns.TypeA { + return fmt.Errorf("expected an A record answer, got %v", dns.TypeToString[r.Answer[0].Header().Rrtype]) + } + if r.Answer[0].(*dns.A).A.String() != "10.95.12.8" { + return fmt.Errorf("expected an A record for 10.95.12.8, got %v", r.Answer[0].String()) + } + return nil + }, + }, + { + name: "ExampleDomainMatchHexIp", + tmpl: exampleDomainAParseIntTemplate, + qclass: dns.ClassINET, + qtype: dns.TypeA, + qname: "ip0a5f0c09.example.", + verifyResponse: func(r *dns.Msg) error { + if len(r.Answer) != 1 { + return fmt.Errorf("expected 1 answer, got %v", len(r.Answer)) + } + if r.Answer[0].Header().Rrtype != dns.TypeA { + return fmt.Errorf("expected an A record answer, got %v", dns.TypeToString[r.Answer[0].Header().Rrtype]) + } + if r.Answer[0].(*dns.A).A.String() != "10.95.12.9" { + return fmt.Errorf("expected an A record for 10.95.12.9, got %v", r.Answer[0].String()) + } + return nil + }, + }, + { + name: "BrokenParseIntTemplate", + tmpl: brokenParseIntTemplate, + qclass: dns.ClassINET, + qtype: dns.TypeANY, + qname: "test.example.", + expectedCode: dns.RcodeServerFailure, + expectedErr: "template: answer:1:26: executing \"answer\" at <parseInt \"gg\" 16 8>: error calling parseInt: strconv.ParseUint: parsing \"gg\": invalid syntax", + verifyResponse: func(r *dns.Msg) error { + return nil + }, + }, + { + name: "ExampleDomainMXMatch", + tmpl: exampleDomainMXTemplate, + qclass: dns.ClassINET, + qtype: dns.TypeMX, + qname: "ip-10-95-12-8.example.", + verifyResponse: func(r *dns.Msg) error { + if len(r.Answer) != 1 { + return fmt.Errorf("expected 1 answer, got %v", len(r.Answer)) + } + if r.Answer[0].Header().Rrtype != dns.TypeMX { + return fmt.Errorf("expected an A record answer, got %v", dns.TypeToString[r.Answer[0].Header().Rrtype]) + } + if len(r.Extra) != 1 { + return fmt.Errorf("expected 1 extra record, got %v", len(r.Extra)) + } + if r.Extra[0].Header().Rrtype != dns.TypeA { + return fmt.Errorf("expected an additional A record, got %v", dns.TypeToString[r.Extra[0].Header().Rrtype]) + } + return nil + }, + }, + { + name: "ExampleDomainANSMatch", + tmpl: exampleDomainANSTemplate, + qclass: dns.ClassINET, + qtype: dns.TypeA, + qname: "ip-10-95-12-8.example.", + verifyResponse: func(r *dns.Msg) error { + if len(r.Answer) != 1 { + return fmt.Errorf("expected 1 answer, got %v", len(r.Answer)) + } + if r.Answer[0].Header().Rrtype != dns.TypeA { + return fmt.Errorf("expected an A record answer, got %v", dns.TypeToString[r.Answer[0].Header().Rrtype]) + } + if len(r.Extra) != 1 { + return 
fmt.Errorf("expected 1 extra record, got %v", len(r.Extra)) + } + if r.Extra[0].Header().Rrtype != dns.TypeA { + return fmt.Errorf("expected an additional A record, got %v", dns.TypeToString[r.Extra[0].Header().Rrtype]) + } + if len(r.Ns) != 1 { + return fmt.Errorf("expected 1 authoritative record, got %v", len(r.Extra)) + } + if r.Ns[0].Header().Rrtype != dns.TypeNS { + return fmt.Errorf("expected an authoritative NS record, got %v", dns.TypeToString[r.Extra[0].Header().Rrtype]) + } + return nil + }, + }, + { + name: "ExampleInvalidNXDOMAIN", + tmpl: invalidDomainTemplate, + qclass: dns.ClassINET, + qtype: dns.TypeMX, + qname: "test.invalid.", + expectedCode: dns.RcodeNameError, + verifyResponse: func(r *dns.Msg) error { + if len(r.Answer) != 1 { + return fmt.Errorf("expected 1 answer, got %v", len(r.Answer)) + } + if r.Answer[0].Header().Rrtype != dns.TypeSOA { + return fmt.Errorf("expected an SOA record answer, got %v", dns.TypeToString[r.Answer[0].Header().Rrtype]) + } + return nil + }, + }, + { + name: "CNAMEWithoutUpstream", + tmpl: cnameTemplate, + qclass: dns.ClassINET, + qtype: dns.TypeA, + qname: "example.net.", + expectedCode: dns.RcodeSuccess, + verifyResponse: func(r *dns.Msg) error { + if len(r.Answer) != 1 { + return fmt.Errorf("expected 1 answer, got %v", len(r.Answer)) + } + return nil + }, + }, + { + name: "mdMatch", + tmpl: mdTemplate, + qclass: dns.ClassINET, + qtype: dns.TypeA, + qname: "ip-10-95-12-8.example.", + verifyResponse: func(r *dns.Msg) error { + if len(r.Answer) != 1 { + return fmt.Errorf("expected 1 answer, got %v", len(r.Answer)) + } + if r.Answer[0].Header().Rrtype != dns.TypeA { + return fmt.Errorf("expected an A record answer, got %v", dns.TypeToString[r.Answer[0].Header().Rrtype]) + } + name := "myfoo-ip-10-95-12-8.example." + if r.Answer[0].Header().Name != name { + return fmt.Errorf("expected answer name %q, got %q", name, r.Answer[0].Header().Name) + } + if len(r.Extra) != 1 { + return fmt.Errorf("expected 1 extra record, got %v", len(r.Extra)) + } + if r.Extra[0].Header().Rrtype != dns.TypeA { + return fmt.Errorf("expected an additional A record, got %v", dns.TypeToString[r.Extra[0].Header().Rrtype]) + } + name = "mybar.example." + if r.Extra[0].Header().Name != name { + return fmt.Errorf("expected additional name %q, got %q", name, r.Extra[0].Header().Name) + } + if len(r.Ns) != 1 { + return fmt.Errorf("expected 1 authoritative record, got %v", len(r.Extra)) + } + if r.Ns[0].Header().Rrtype != dns.TypeNS { + return fmt.Errorf("expected an authoritative NS record, got %v", dns.TypeToString[r.Extra[0].Header().Rrtype]) + } + ns, ok := r.Ns[0].(*dns.NS) + if !ok { + return fmt.Errorf("expected NS record to be type NS, got %v", r.Ns[0]) + } + rdata := "mybar.example.com." + if ns.Ns != rdata { + return fmt.Errorf("expected ns rdata %q, got %q", rdata, ns.Ns) + } + return nil + }, + md: map[string]string{ + "foo": "myfoo", + "bar": "mybar", + "foobar": "myfoobar", + }, + }, + { + name: "mdMissing", + tmpl: mdMissingTemplate, + qclass: dns.ClassINET, + qtype: dns.TypeA, + qname: "ip-10-95-12-8.example.", + verifyResponse: func(r *dns.Msg) error { + if len(r.Answer) != 1 { + return fmt.Errorf("expected 1 answer, got %v", len(r.Answer)) + } + if r.Answer[0].Header().Rrtype != dns.TypeA { + return fmt.Errorf("expected an A record answer, got %v", dns.TypeToString[r.Answer[0].Header().Rrtype]) + } + name := "ip-10-95-12-8.example." 
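+				// The "foofoo" metadata key is not set in this test, so .Meta renders an
+				// empty string and the templated owner name is just the plain query name.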
+ if r.Answer[0].Header().Name != name { + return fmt.Errorf("expected answer name %q, got %q", name, r.Answer[0].Header().Name) + } + return nil + }, + md: map[string]string{ + "foo": "myfoo", + }, + }, + { + name: "EDNS error", + tmpl: templateWithEDE, + qclass: dns.ClassINET, + qtype: dns.TypeA, + qname: "test.invalid.", + expectedCode: dns.RcodeNameError, + verifyResponse: func(r *dns.Msg) error { + if opt := r.IsEdns0(); opt != nil { + matched := false + for _, ednsopt := range opt.Option { + if ede, ok := ednsopt.(*dns.EDNS0_EDE); ok { + if ede.InfoCode != dns.ExtendedErrorCodeNotSupported { + return fmt.Errorf("unexpected EDE code = %v, want %v", ede.InfoCode, dns.ExtendedErrorCodeNotSupported) + } + matched = true + } + } + if !matched { + t.Error("Error: acl.ServeDNS() missing Extended DNS Error option") + } + } else { + return fmt.Errorf("expected EDNS enabled") + } + return nil + }, + }, + } + + ctx := context.TODO() + + for _, tr := range tests { + handler := Handler{ + Next: test.NextHandler(rcodeFallthrough, nil), + Zones: []string{"."}, + Templates: []template{tr.tmpl}, + } + req := &dns.Msg{ + Question: []dns.Question{{ + Name: tr.qname, + Qclass: tr.qclass, + Qtype: tr.qtype, + }}, + } + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + if tr.md != nil { + ctx = metadata.ContextWithMetadata(context.Background()) + + for k, v := range tr.md { + // Go requires copying to a local variable for the closure to work + kk := k + vv := v + metadata.SetValueFunc(ctx, kk, func() string { + return vv + }) + } + } + + code, err := handler.ServeDNS(ctx, rec, req) + if err == nil && tr.expectedErr != "" { + t.Errorf("Test %v expected error: %v, got nothing", tr.name, tr.expectedErr) + } + if err != nil && tr.expectedErr == "" { + t.Errorf("Test %v expected no error got: %v", tr.name, err) + } + if err != nil && tr.expectedErr != "" && err.Error() != tr.expectedErr { + t.Errorf("Test %v expected error: %v, got: %v", tr.name, tr.expectedErr, err) + } + if code != tr.expectedCode { + t.Errorf("Test %v expected response code %v, got %v", tr.name, tr.expectedCode, code) + } + if err == nil && code != rcodeFallthrough { + // only verify if we got no error and expected no error + if err := tr.verifyResponse(rec.Msg); err != nil { + t.Errorf("Test %v could not verify the response: %v", tr.name, err) + } + } + } +} + +// TestMultiSection verifies that a corefile with multiple but different template sections works +func TestMultiSection(t *testing.T) { + ctx := context.TODO() + + multisectionConfig := ` + # Implicit section (see c.ServerBlockKeys) + # test.:8053 { + + # REFUSE IN A for the server zone (test.) + template IN A { + rcode REFUSED + } + # Fallthrough everything IN TXT for test. + template IN TXT { + match "$^" + rcode SERVFAIL + fallthrough + } + # Answer CH TXT *.coredns.invalid. / coredns.invalid. + template CH TXT coredns.invalid { + answer "{{ .Name }} 60 CH TXT \"test\"" + } + # Answer example. ip templates and fallthrough otherwise + template IN A example { + match ^ip-10-(?P<b>[0-9]*)-(?P<c>[0-9]*)-(?P<d>[0-9]*)[.]example[.]$ + answer "{{ .Name }} 60 IN A 10.{{ .Group.b }}.{{ .Group.c }}.{{ .Group.d }}" + fallthrough + } + # Answer MX record requests for ip templates in example. 
and never fall through + template IN MX example { + match ^ip-10-(?P<b>[0-9]*)-(?P<c>[0-9]*)-(?P<d>[0-9]*)[.]example[.]$ + answer "{{ .Name }} 60 IN MX 10 {{ .Name }}" + additional "{{ .Name }} 60 IN A 10.{{ .Group.b }}.{{ .Group.c }}.{{ .Group.d }}" + } + ` + c := caddy.NewTestController("dns", multisectionConfig) + c.ServerBlockKeys = []string{"test.:8053"} + + handler, err := templateParse(c) + if err != nil { + t.Fatalf("TestMultiSection could not parse config: %v", err) + } + + handler.Next = test.NextHandler(rcodeFallthrough, nil) + + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + + // Asking for test. IN A -> REFUSED + + req := &dns.Msg{Question: []dns.Question{{Name: "some.test.", Qclass: dns.ClassINET, Qtype: dns.TypeA}}} + code, err := handler.ServeDNS(ctx, rec, req) + if err != nil { + t.Fatalf("TestMultiSection expected no error resolving some.test. A, got: %v", err) + } + if code != dns.RcodeRefused { + t.Fatalf("TestMultiSection expected response code REFUSED got: %v", code) + } + + // Asking for test. IN TXT -> fallthrough + + req = &dns.Msg{Question: []dns.Question{{Name: "some.test.", Qclass: dns.ClassINET, Qtype: dns.TypeTXT}}} + code, err = handler.ServeDNS(ctx, rec, req) + if err != nil { + t.Fatalf("TestMultiSection expected no error resolving some.test. TXT, got: %v", err) + } + if code != rcodeFallthrough { + t.Fatalf("TestMultiSection expected response code fallthrough got: %v", code) + } + + // Asking for coredns.invalid. CH TXT -> TXT "test" + + req = &dns.Msg{Question: []dns.Question{{Name: "coredns.invalid.", Qclass: dns.ClassCHAOS, Qtype: dns.TypeTXT}}} + code, err = handler.ServeDNS(ctx, rec, req) + if err != nil { + t.Fatalf("TestMultiSection expected no error resolving coredns.invalid. TXT, got: %v", err) + } + if code != dns.RcodeSuccess { + t.Fatalf("TestMultiSection expected success response for coredns.invalid. TXT got: %v", code) + } + if len(rec.Msg.Answer) != 1 { + t.Fatalf("TestMultiSection expected one answer for coredns.invalid. TXT got: %v", rec.Msg.Answer) + } + if rec.Msg.Answer[0].Header().Rrtype != dns.TypeTXT || rec.Msg.Answer[0].(*dns.TXT).Txt[0] != "test" { + t.Fatalf("TestMultiSection a \"test\" answer for coredns.invalid. TXT got: %v", rec.Msg.Answer[0]) + } + + // Asking for an ip template in example + + req = &dns.Msg{Question: []dns.Question{{Name: "ip-10-11-12-13.example.", Qclass: dns.ClassINET, Qtype: dns.TypeA}}} + code, err = handler.ServeDNS(ctx, rec, req) + if err != nil { + t.Fatalf("TestMultiSection expected no error resolving ip-10-11-12-13.example. IN A, got: %v", err) + } + if code != dns.RcodeSuccess { + t.Fatalf("TestMultiSection expected success response ip-10-11-12-13.example. IN A got: %v, %v", code, dns.RcodeToString[code]) + } + if len(rec.Msg.Answer) != 1 { + t.Fatalf("TestMultiSection expected one answer for ip-10-11-12-13.example. IN A got: %v", rec.Msg.Answer) + } + if rec.Msg.Answer[0].Header().Rrtype != dns.TypeA { + t.Fatalf("TestMultiSection an A RR answer for ip-10-11-12-13.example. IN A got: %v", rec.Msg.Answer[0]) + } + + // Asking for an MX ip template in example + + req = &dns.Msg{Question: []dns.Question{{Name: "ip-10-11-12-13.example.", Qclass: dns.ClassINET, Qtype: dns.TypeMX}}} + code, err = handler.ServeDNS(ctx, rec, req) + if err != nil { + t.Fatalf("TestMultiSection expected no error resolving ip-10-11-12-13.example. IN MX, got: %v", err) + } + if code != dns.RcodeSuccess { + t.Fatalf("TestMultiSection expected success response ip-10-11-12-13.example. 
IN MX got: %v, %v", code, dns.RcodeToString[code]) + } + if len(rec.Msg.Answer) != 1 { + t.Fatalf("TestMultiSection expected one answer for ip-10-11-12-13.example. IN MX got: %v", rec.Msg.Answer) + } + if rec.Msg.Answer[0].Header().Rrtype != dns.TypeMX { + t.Fatalf("TestMultiSection an A RR answer for ip-10-11-12-13.example. IN MX got: %v", rec.Msg.Answer[0]) + } + + // Test that something.example. A does fall through but something.example. MX does not + + req = &dns.Msg{Question: []dns.Question{{Name: "something.example.", Qclass: dns.ClassINET, Qtype: dns.TypeA}}} + code, err = handler.ServeDNS(ctx, rec, req) + if err != nil { + t.Fatalf("TestMultiSection expected no error resolving something.example. IN A, got: %v", err) + } + if code != rcodeFallthrough { + t.Fatalf("TestMultiSection expected a fall through resolving something.example. IN A, got: %v, %v", code, dns.RcodeToString[code]) + } + + req = &dns.Msg{Question: []dns.Question{{Name: "something.example.", Qclass: dns.ClassINET, Qtype: dns.TypeMX}}} + code, err = handler.ServeDNS(ctx, rec, req) + if err != nil { + t.Fatalf("TestMultiSection expected no error resolving something.example. IN MX, got: %v", err) + } + if code == rcodeFallthrough { + t.Fatalf("TestMultiSection expected no fall through resolving something.example. IN MX") + } + if code != dns.RcodeServerFailure { + t.Fatalf("TestMultiSection expected SERVFAIL resolving something.example. IN MX, got %v, %v", code, dns.RcodeToString[code]) + } +} + +const rcodeFallthrough = 3841 // reserved for private use, used to indicate a fallthrough diff --git a/plugin/test/doc.go b/plugin/test/doc.go new file mode 100644 index 0000000..75281ed --- /dev/null +++ b/plugin/test/doc.go @@ -0,0 +1,2 @@ +// Package test contains helper functions for writing plugin tests. +package test diff --git a/plugin/test/file.go b/plugin/test/file.go new file mode 100644 index 0000000..667b6a3 --- /dev/null +++ b/plugin/test/file.go @@ -0,0 +1,103 @@ +package test + +import ( + "os" + "path/filepath" + "testing" +) + +// TempFile will create a temporary file on disk and returns the name and a cleanup function to remove it later. 
+func TempFile(dir, content string) (string, func(), error) { + f, err := os.CreateTemp(dir, "go-test-tmpfile") + if err != nil { + return "", nil, err + } + if err := os.WriteFile(f.Name(), []byte(content), 0644); err != nil { + return "", nil, err + } + rmFunc := func() { os.Remove(f.Name()) } + return f.Name(), rmFunc, nil +} + +// WritePEMFiles creates a tmp dir with ca.pem, cert.pem, and key.pem +func WritePEMFiles(t *testing.T) (string, error) { + tempDir := t.TempDir() + + data := `-----BEGIN CERTIFICATE----- +MIIC9zCCAd+gAwIBAgIJALGtqdMzpDemMA0GCSqGSIb3DQEBCwUAMBIxEDAOBgNV +BAMMB2t1YmUtY2EwHhcNMTYxMDE5MTU1NDI0WhcNNDQwMzA2MTU1NDI0WjASMRAw +DgYDVQQDDAdrdWJlLWNhMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA +pa4Wu/WkpJNRr8pMVE6jjwzNUOx5mIyoDr8WILSxVQcEeyVPPmAqbmYXtVZO11p9 +jTzoEqF7Kgts3HVYGCk5abqbE14a8Ru/DmV5avU2hJ/NvSjtNi/O+V6SzCbg5yR9 +lBR53uADDlzuJEQT9RHq7A5KitFkx4vUcXnjOQCbDogWFoYuOgNEwJPy0Raz3NJc +ViVfDqSJ0QHg02kCOMxcGFNRQ9F5aoW7QXZXZXD0tn3wLRlu4+GYyqt8fw5iNdLJ +t79yKp8I+vMTmMPz4YKUO+eCl5EY10Qs7wvoG/8QNbjH01BRN3L8iDT2WfxdvjTu +1RjPxFL92i+B7HZO7jGLfQIDAQABo1AwTjAdBgNVHQ4EFgQUZTrg+Xt87tkxDhlB +gKk9FdTOW3IwHwYDVR0jBBgwFoAUZTrg+Xt87tkxDhlBgKk9FdTOW3IwDAYDVR0T +BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEApB7JFVrZpGSOXNO3W7SlN6OCPXv9 +C7rIBc8rwOrzi2mZWcBmWheQrqBo8xHif2rlFNVQxtq3JcQ8kfg/m1fHeQ/Ygzel +Z+U1OqozynDySBZdNn9i+kXXgAUCqDPp3hEQWe0os/RRpIwo9yOloBxdiX6S0NIf +VB8n8kAynFPkH7pYrGrL1HQgDFCSfa4tUJ3+9sppnCu0pNtq5AdhYx9xFb2sn+8G +xGbtCkhVk2VQ+BiCWnjYXJ6ZMzabP7wiOFDP9Pvr2ik22PRItsW/TLfHFXM1jDmc +I1rs/VUGKzcJGVIWbHrgjP68CTStGAvKgbsTqw7aLXTSqtPw88N9XVSyRg== +-----END CERTIFICATE-----` + path := filepath.Join(tempDir, "ca.pem") + if err := os.WriteFile(path, []byte(data), 0644); err != nil { + return "", err + } + data = `-----BEGIN CERTIFICATE----- +MIICozCCAYsCCQCRlf5BrvPuqjANBgkqhkiG9w0BAQsFADASMRAwDgYDVQQDDAdr +dWJlLWNhMB4XDTE2MTAxOTE2MDUxOFoXDTE3MTAxOTE2MDUxOFowFTETMBEGA1UE +AwwKa3ViZS1hZG1pbjCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMTw +a7wCFoiCad/N53aURfjrme+KR7FS0yf5Ur9OR/oM3BoS9stYu5Flzr35oL5T6t5G +c2ey78mUs/Cs07psnjUdKH55bDpJSdG7zW9mXNyeLwIefFcj/38SS5NBSotmLo8u +scJMGXeQpCQtfVuVJSP2bfU5u5d0KTLSg/Cor6UYonqrRB82HbOuuk8Wjaww4VHo +nCq7X8o948V6HN5ZibQOgMMo+nf0wORREHBjvwc4W7ewbaTcfoe1VNAo/QnkqxTF +ueMb2HxgghArqQSK8b44O05V0zrde25dVnmnte6sPjcV0plqMJ37jViISxsOPUFh +/ZW7zbIM/7CMcDekCiECAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAYZE8OxwRR7GR +kdd5aIriDwWfcl56cq5ICyx87U8hAZhBxk46a6a901LZPzt3xKyWIFQSRj/NYiQ+ +/thjGLZI2lhkVgYtyAD4BNxDiuppQSCbkjY9tLVDdExGttEVN7+UYDWJBHy6X16Y +xSG9FE3Dvp9LI89Nq8E3dRh+Q8wu52q9HaQXjS5YtzQOtDFKPBkihXu/c6gEHj4Y +bZVk8rFiH8/CvcQxAuvNI3VVCFUKd2LeQtqwYQQ//qoiuA15krTq5Ut9eXJ8zxAw +zhDEPP4FhY+Sz+y1yWirphl7A1aZwhXVPcfWIGqpQ3jzNwUeocbH27kuLh+U4hQo +qeg10RdFnw== +-----END CERTIFICATE-----` + path = filepath.Join(tempDir, "cert.pem") + if err := os.WriteFile(path, []byte(data), 0644); err != nil { + return "", err + } + + data = `-----BEGIN RSA PRIVATE KEY----- +MIIEpgIBAAKCAQEAxPBrvAIWiIJp383ndpRF+OuZ74pHsVLTJ/lSv05H+gzcGhL2 +y1i7kWXOvfmgvlPq3kZzZ7LvyZSz8KzTumyeNR0ofnlsOklJ0bvNb2Zc3J4vAh58 +VyP/fxJLk0FKi2Yujy6xwkwZd5CkJC19W5UlI/Zt9Tm7l3QpMtKD8KivpRiieqtE +HzYds666TxaNrDDhUeicKrtfyj3jxXoc3lmJtA6Awyj6d/TA5FEQcGO/Bzhbt7Bt +pNx+h7VU0Cj9CeSrFMW54xvYfGCCECupBIrxvjg7TlXTOt17bl1Weae17qw+NxXS +mWownfuNWIhLGw49QWH9lbvNsgz/sIxwN6QKIQIDAQABAoIBAQDCXq9V7ZGjxWMN +OkFaLVkqJg3V91puztoMt+xNV8t+JTcOnOzrIXZuOFbl9PwLHPPP0SSRkm9LOvKl +dU26zv0OWureeKSymia7U2mcqyC3tX+bzc7WinbeSYZBnc0e7AjD1EgpBcaU1TLL +agIxY3A2oD9CKmrVPhZzTIZf/XztqTYjhvs5I2kBeT0imdYGpXkdndRyGX4I5/JQ 
+fnp3Czj+AW3zX7RvVnXOh4OtIAcfoG9xoNyD5LOSlJkkX0MwTS8pEBeZA+A4nb+C +ivjnOSgXWD+liisI+LpBgBbwYZ/E49x5ghZYrJt8QXSk7Bl/+UOyv6XZAm2mev6j +RLAZtoABAoGBAP2P+1PoKOwsk+d/AmHqyTCUQm0UG18LOLB/5PyWfXs/6caDmdIe +DZWeZWng1jUQLEadmoEw/CBY5+tPfHlzwzMNhT7KwUfIDQCIBoS7dzHYnwrJ3VZh +qYA05cuGHAAHqwb6UWz3y6Pa4AEVSHX6CM83CAi9jdWZ1rdZybWG+qYBAoGBAMbV +FsR/Ft+tK5ALgXGoG83TlmxzZYuZ1SnNje1OSdCQdMFCJB10gwoaRrw1ICzi40Xk +ydJwV1upGz1om9ReDAD1zQM9artmQx6+TVLiVPALuARdZE70+NrA6w3ZvxUgJjdN +ngvXUr+8SdvaYUAwFu7BulfJlwXjUS711hHW/KQhAoGBALY41QuV2mLwHlLNie7I +hlGtGpe9TXZeYB0nrG6B0CfU5LJPPSotguG1dXhDpm138/nDpZeWlnrAqdsHwpKd +yPhVjR51I7XsZLuvBdA50Q03egSM0c4UXXXPjh1XgaPb3uMi3YWMBwL4ducQXoS6 +bb5M9C8j2lxZNF+L3VPhbxwBAoGBAIEWDvX7XKpTDxkxnxRfA84ZNGusb5y2fsHp +Bd+vGBUj8+kUO8Yzwm9op8vA4ebCVrMl2jGZZd3IaDryE1lIxZpJ+pPD5+tKdQEc +o67P6jz+HrYWu+zW9klvPit71qasfKMi7Rza6oo4f+sQWFsH3ZucgpJD+pyD/Ez0 +pcpnPRaBAoGBANT/xgHBfIWt4U2rtmRLIIiZxKr+3mGnQdpA1J2BCh+/6AvrEx// +E/WObVJXDnBdViu0L9abE9iaTToBVri4cmlDlZagLuKVR+TFTCN/DSlVZTDkqkLI +8chzqtkH6b2b2R73hyRysWjsomys34ma3mEEPTX/aXeAF2MSZ/EWT9yL +-----END RSA PRIVATE KEY-----` + path = filepath.Join(tempDir, "key.pem") + if err := os.WriteFile(path, []byte(data), 0644); err != nil { + return "", err + } + + return tempDir, nil +} diff --git a/plugin/test/file_test.go b/plugin/test/file_test.go new file mode 100644 index 0000000..b225ace --- /dev/null +++ b/plugin/test/file_test.go @@ -0,0 +1,11 @@ +package test + +import "testing" + +func TestTempFile(t *testing.T) { + _, f, e := TempFile(".", "test") + if e != nil { + t.Fatalf("Failed to create temp file: %s", e) + } + defer f() +} diff --git a/plugin/test/helpers.go b/plugin/test/helpers.go new file mode 100644 index 0000000..f99790a --- /dev/null +++ b/plugin/test/helpers.go @@ -0,0 +1,333 @@ +package test + +import ( + "context" + "fmt" + "sort" + + "github.com/miekg/dns" +) + +type sect int + +const ( + // Answer is the answer section in an Msg. + Answer sect = iota + // Ns is the authoritative section in an Msg. + Ns + // Extra is the additional section in an Msg. + Extra +) + +// RRSet represents a list of RRs. +type RRSet []dns.RR + +func (p RRSet) Len() int { return len(p) } +func (p RRSet) Swap(i, j int) { p[i], p[j] = p[j], p[i] } +func (p RRSet) Less(i, j int) bool { return p[i].String() < p[j].String() } + +// Case represents a test case that encapsulates various data from a query and response. +// Note that is the TTL of a record is 303 we don't compare it with the TTL. +type Case struct { + Qname string + Qtype uint16 + Rcode int + Do bool + CheckingDisabled bool + RecursionAvailable bool + AuthenticatedData bool + Authoritative bool + Truncated bool + Answer []dns.RR + Ns []dns.RR + Extra []dns.RR + Error error +} + +// Msg returns a *dns.Msg embedded in c. +func (c Case) Msg() *dns.Msg { + m := new(dns.Msg) + m.SetQuestion(dns.Fqdn(c.Qname), c.Qtype) + if c.Do { + o := new(dns.OPT) + o.Hdr.Name = "." + o.Hdr.Rrtype = dns.TypeOPT + o.SetDo() + o.SetUDPSize(4096) + m.Extra = []dns.RR{o} + } + return m +} + +// A returns an A record from rr. It panics on errors. +func A(rr string) *dns.A { r, _ := dns.NewRR(rr); return r.(*dns.A) } + +// AAAA returns an AAAA record from rr. It panics on errors. +func AAAA(rr string) *dns.AAAA { r, _ := dns.NewRR(rr); return r.(*dns.AAAA) } + +// CNAME returns a CNAME record from rr. It panics on errors. +func CNAME(rr string) *dns.CNAME { r, _ := dns.NewRR(rr); return r.(*dns.CNAME) } + +// DNAME returns a DNAME record from rr. It panics on errors. 
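+// For example (illustrative): DNAME("example.org. 300 IN DNAME example.net.") builds an expected RR for a Case.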
+func DNAME(rr string) *dns.DNAME { r, _ := dns.NewRR(rr); return r.(*dns.DNAME) } + +// SRV returns a SRV record from rr. It panics on errors. +func SRV(rr string) *dns.SRV { r, _ := dns.NewRR(rr); return r.(*dns.SRV) } + +// SOA returns a SOA record from rr. It panics on errors. +func SOA(rr string) *dns.SOA { r, _ := dns.NewRR(rr); return r.(*dns.SOA) } + +// NS returns an NS record from rr. It panics on errors. +func NS(rr string) *dns.NS { r, _ := dns.NewRR(rr); return r.(*dns.NS) } + +// PTR returns a PTR record from rr. It panics on errors. +func PTR(rr string) *dns.PTR { r, _ := dns.NewRR(rr); return r.(*dns.PTR) } + +// TXT returns a TXT record from rr. It panics on errors. +func TXT(rr string) *dns.TXT { r, _ := dns.NewRR(rr); return r.(*dns.TXT) } + +// CAA returns a CAA record from rr. It panics on errors. +func CAA(rr string) *dns.CAA { r, _ := dns.NewRR(rr); return r.(*dns.CAA) } + +// HINFO returns a HINFO record from rr. It panics on errors. +func HINFO(rr string) *dns.HINFO { r, _ := dns.NewRR(rr); return r.(*dns.HINFO) } + +// MX returns an MX record from rr. It panics on errors. +func MX(rr string) *dns.MX { r, _ := dns.NewRR(rr); return r.(*dns.MX) } + +// RRSIG returns an RRSIG record from rr. It panics on errors. +func RRSIG(rr string) *dns.RRSIG { r, _ := dns.NewRR(rr); return r.(*dns.RRSIG) } + +// NSEC returns an NSEC record from rr. It panics on errors. +func NSEC(rr string) *dns.NSEC { r, _ := dns.NewRR(rr); return r.(*dns.NSEC) } + +// DNSKEY returns a DNSKEY record from rr. It panics on errors. +func DNSKEY(rr string) *dns.DNSKEY { r, _ := dns.NewRR(rr); return r.(*dns.DNSKEY) } + +// DS returns a DS record from rr. It panics on errors. +func DS(rr string) *dns.DS { r, _ := dns.NewRR(rr); return r.(*dns.DS) } + +// NAPTR returns a NAPTR record from rr. It panics on errors. +func NAPTR(rr string) *dns.NAPTR { r, _ := dns.NewRR(rr); return r.(*dns.NAPTR) } + +// OPT returns an OPT record with UDP buffer size set to bufsize and the DO bit set to do. +func OPT(bufsize int, do bool) *dns.OPT { + o := new(dns.OPT) + o.Hdr.Name = "." + o.Hdr.Rrtype = dns.TypeOPT + o.SetVersion(0) + o.SetUDPSize(uint16(bufsize)) + if do { + o.SetDo() + } + return o +} + +// Header tests if the header in resp matches the header as defined in tc. +func Header(tc Case, resp *dns.Msg) error { + if resp.Rcode != tc.Rcode { + return fmt.Errorf("rcode is %q, expected %q", dns.RcodeToString[resp.Rcode], dns.RcodeToString[tc.Rcode]) + } + + if len(resp.Answer) != len(tc.Answer) { + return fmt.Errorf("answer for %q contained %d results, %d expected", tc.Qname, len(resp.Answer), len(tc.Answer)) + } + if len(resp.Ns) != len(tc.Ns) { + return fmt.Errorf("authority for %q contained %d results, %d expected", tc.Qname, len(resp.Ns), len(tc.Ns)) + } + if len(resp.Extra) != len(tc.Extra) { + return fmt.Errorf("additional for %q contained %d results, %d expected", tc.Qname, len(resp.Extra), len(tc.Extra)) + } + return nil +} + +// Section tests if the section in tc matches rr. +func Section(tc Case, sec sect, rr []dns.RR) error { + section := []dns.RR{} + switch sec { + case 0: + section = tc.Answer + case 1: + section = tc.Ns + case 2: + section = tc.Extra + } + + for i, a := range rr { + if a.Header().Name != section[i].Header().Name { + return fmt.Errorf("RR %d should have a Header Name of %q, but has %q", i, section[i].Header().Name, a.Header().Name) + } + // 303 signals: don't care what the ttl is. 
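+		// OPT pseudo-RRs are also exempt: their header TTL encodes EDNS version and flags rather
+		// than a real TTL, and their EDNS fields are compared in the *dns.OPT case below.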
+ if section[i].Header().Ttl != 303 && a.Header().Ttl != section[i].Header().Ttl { + if _, ok := section[i].(*dns.OPT); !ok { + // we check edns0 bufize on this one + return fmt.Errorf("RR %d should have a Header TTL of %d, but has %d", i, section[i].Header().Ttl, a.Header().Ttl) + } + } + if a.Header().Rrtype != section[i].Header().Rrtype { + return fmt.Errorf("RR %d should have a header rr type of %d, but has %d", i, section[i].Header().Rrtype, a.Header().Rrtype) + } + + switch x := a.(type) { + case *dns.SRV: + if x.Priority != section[i].(*dns.SRV).Priority { + return fmt.Errorf("RR %d should have a Priority of %d, but has %d", i, section[i].(*dns.SRV).Priority, x.Priority) + } + if x.Weight != section[i].(*dns.SRV).Weight { + return fmt.Errorf("RR %d should have a Weight of %d, but has %d", i, section[i].(*dns.SRV).Weight, x.Weight) + } + if x.Port != section[i].(*dns.SRV).Port { + return fmt.Errorf("RR %d should have a Port of %d, but has %d", i, section[i].(*dns.SRV).Port, x.Port) + } + if x.Target != section[i].(*dns.SRV).Target { + return fmt.Errorf("RR %d should have a Target of %q, but has %q", i, section[i].(*dns.SRV).Target, x.Target) + } + case *dns.RRSIG: + if x.TypeCovered != section[i].(*dns.RRSIG).TypeCovered { + return fmt.Errorf("RR %d should have a TypeCovered of %d, but has %d", i, section[i].(*dns.RRSIG).TypeCovered, x.TypeCovered) + } + if x.Labels != section[i].(*dns.RRSIG).Labels { + return fmt.Errorf("RR %d should have a Labels of %d, but has %d", i, section[i].(*dns.RRSIG).Labels, x.Labels) + } + if x.SignerName != section[i].(*dns.RRSIG).SignerName { + return fmt.Errorf("RR %d should have a SignerName of %s, but has %s", i, section[i].(*dns.RRSIG).SignerName, x.SignerName) + } + case *dns.NSEC: + if x.NextDomain != section[i].(*dns.NSEC).NextDomain { + return fmt.Errorf("RR %d should have a NextDomain of %s, but has %s", i, section[i].(*dns.NSEC).NextDomain, x.NextDomain) + } + // TypeBitMap + case *dns.A: + if x.A.String() != section[i].(*dns.A).A.String() { + return fmt.Errorf("RR %d should have a Address of %q, but has %q", i, section[i].(*dns.A).A.String(), x.A.String()) + } + case *dns.AAAA: + if x.AAAA.String() != section[i].(*dns.AAAA).AAAA.String() { + return fmt.Errorf("RR %d should have a Address of %q, but has %q", i, section[i].(*dns.AAAA).AAAA.String(), x.AAAA.String()) + } + case *dns.TXT: + for j, txt := range x.Txt { + if txt != section[i].(*dns.TXT).Txt[j] { + return fmt.Errorf("RR %d should have a Txt of %q, but has %q", i, section[i].(*dns.TXT).Txt[j], txt) + } + } + case *dns.HINFO: + if x.Cpu != section[i].(*dns.HINFO).Cpu { + return fmt.Errorf("RR %d should have a Cpu of %s, but has %s", i, section[i].(*dns.HINFO).Cpu, x.Cpu) + } + if x.Os != section[i].(*dns.HINFO).Os { + return fmt.Errorf("RR %d should have a Os of %s, but has %s", i, section[i].(*dns.HINFO).Os, x.Os) + } + case *dns.SOA: + tt := section[i].(*dns.SOA) + if x.Ns != tt.Ns { + return fmt.Errorf("SOA nameserver should be %q, but is %q", tt.Ns, x.Ns) + } + case *dns.PTR: + tt := section[i].(*dns.PTR) + if x.Ptr != tt.Ptr { + return fmt.Errorf("PTR ptr should be %q, but is %q", tt.Ptr, x.Ptr) + } + case *dns.CNAME: + tt := section[i].(*dns.CNAME) + if x.Target != tt.Target { + return fmt.Errorf("CNAME target should be %q, but is %q", tt.Target, x.Target) + } + case *dns.MX: + tt := section[i].(*dns.MX) + if x.Mx != tt.Mx { + return fmt.Errorf("MX Mx should be %q, but is %q", tt.Mx, x.Mx) + } + if x.Preference != tt.Preference { + return fmt.Errorf("MX Preference should be %q, 
but is %q", tt.Preference, x.Preference) + } + case *dns.NS: + tt := section[i].(*dns.NS) + if x.Ns != tt.Ns { + return fmt.Errorf("NS nameserver should be %q, but is %q", tt.Ns, x.Ns) + } + case *dns.OPT: + tt := section[i].(*dns.OPT) + if x.UDPSize() != tt.UDPSize() { + return fmt.Errorf("OPT UDPSize should be %d, but is %d", tt.UDPSize(), x.UDPSize()) + } + if x.Do() != tt.Do() { + return fmt.Errorf("OPT DO should be %t, but is %t", tt.Do(), x.Do()) + } + } + } + return nil +} + +// CNAMEOrder makes sure that CNAMES do not appear after their target records. +func CNAMEOrder(res *dns.Msg) error { + for i, c := range res.Answer { + if c.Header().Rrtype != dns.TypeCNAME { + continue + } + for _, a := range res.Answer[:i] { + if a.Header().Name != c.(*dns.CNAME).Target { + continue + } + return fmt.Errorf("CNAME found after target record") + } + } + return nil +} + +// SortAndCheck sorts resp and the checks the header and three sections against the testcase in tc. +func SortAndCheck(resp *dns.Msg, tc Case) error { + sort.Sort(RRSet(resp.Answer)) + sort.Sort(RRSet(resp.Ns)) + sort.Sort(RRSet(resp.Extra)) + + if err := Header(tc, resp); err != nil { + return err + } + if err := Section(tc, Answer, resp.Answer); err != nil { + return err + } + if err := Section(tc, Ns, resp.Ns); err != nil { + return err + } + return Section(tc, Extra, resp.Extra) +} + +// ErrorHandler returns a Handler that returns ServerFailure error when called. +func ErrorHandler() Handler { + return HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + m := new(dns.Msg) + m.SetRcode(r, dns.RcodeServerFailure) + w.WriteMsg(m) + return dns.RcodeServerFailure, nil + }) +} + +// NextHandler returns a Handler that returns rcode and err. +func NextHandler(rcode int, err error) Handler { + return HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + return rcode, err + }) +} + +// Copied here to prevent an import cycle, so that we can define to above handlers. + +type ( + // HandlerFunc is a convenience type like dns.HandlerFunc, except + // ServeDNS returns an rcode and an error. + HandlerFunc func(context.Context, dns.ResponseWriter, *dns.Msg) (int, error) + + // Handler interface defines a plugin. + Handler interface { + ServeDNS(context.Context, dns.ResponseWriter, *dns.Msg) (int, error) + Name() string + } +) + +// ServeDNS implements the Handler interface. +func (f HandlerFunc) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + return f(ctx, w, r) +} + +// Name implements the Handler interface. +func (f HandlerFunc) Name() string { return "handlerfunc" } diff --git a/plugin/test/responsewriter.go b/plugin/test/responsewriter.go new file mode 100644 index 0000000..3216700 --- /dev/null +++ b/plugin/test/responsewriter.go @@ -0,0 +1,80 @@ +package test + +import ( + "net" + + "github.com/miekg/dns" +) + +// ResponseWriter is useful for writing tests. It uses some fixed values for the client. The +// remote will always be 10.240.0.1 and port 40212. The local address is always 127.0.0.1 and +// port 53. +type ResponseWriter struct { + TCP bool // if TCP is true we return an TCP connection instead of an UDP one. + RemoteIP string + Zone string +} + +// LocalAddr returns the local address, 127.0.0.1:53 (UDP, TCP if t.TCP is true). 
+func (t *ResponseWriter) LocalAddr() net.Addr { + ip := net.ParseIP("127.0.0.1") + port := 53 + if t.TCP { + return &net.TCPAddr{IP: ip, Port: port, Zone: ""} + } + return &net.UDPAddr{IP: ip, Port: port, Zone: ""} +} + +// RemoteAddr returns the remote address, defaults to 10.240.0.1:40212 (UDP, TCP is t.TCP is true). +func (t *ResponseWriter) RemoteAddr() net.Addr { + remoteIP := "10.240.0.1" + if t.RemoteIP != "" { + remoteIP = t.RemoteIP + } + ip := net.ParseIP(remoteIP) + port := 40212 + if t.TCP { + return &net.TCPAddr{IP: ip, Port: port, Zone: t.Zone} + } + return &net.UDPAddr{IP: ip, Port: port, Zone: t.Zone} +} + +// WriteMsg implements dns.ResponseWriter interface. +func (t *ResponseWriter) WriteMsg(m *dns.Msg) error { return nil } + +// Write implements dns.ResponseWriter interface. +func (t *ResponseWriter) Write(buf []byte) (int, error) { return len(buf), nil } + +// Close implements dns.ResponseWriter interface. +func (t *ResponseWriter) Close() error { return nil } + +// TsigStatus implements dns.ResponseWriter interface. +func (t *ResponseWriter) TsigStatus() error { return nil } + +// TsigTimersOnly implements dns.ResponseWriter interface. +func (t *ResponseWriter) TsigTimersOnly(bool) {} + +// Hijack implements dns.ResponseWriter interface. +func (t *ResponseWriter) Hijack() {} + +// ResponseWriter6 returns fixed client and remote address in IPv6. The remote +// address is always fe80::42:ff:feca:4c65 and port 40212. The local address is always ::1 and port 53. +type ResponseWriter6 struct { + ResponseWriter +} + +// LocalAddr returns the local address, always ::1, port 53 (UDP, TCP is t.TCP is true). +func (t *ResponseWriter6) LocalAddr() net.Addr { + if t.TCP { + return &net.TCPAddr{IP: net.ParseIP("::1"), Port: 53, Zone: ""} + } + return &net.UDPAddr{IP: net.ParseIP("::1"), Port: 53, Zone: ""} +} + +// RemoteAddr returns the remote address, always fe80::42:ff:feca:4c65 port 40212 (UDP, TCP is t.TCP is true). +func (t *ResponseWriter6) RemoteAddr() net.Addr { + if t.TCP { + return &net.TCPAddr{IP: net.ParseIP("fe80::42:ff:feca:4c65"), Port: 40212, Zone: ""} + } + return &net.UDPAddr{IP: net.ParseIP("fe80::42:ff:feca:4c65"), Port: 40212, Zone: ""} +} diff --git a/plugin/test/scrape.go b/plugin/test/scrape.go new file mode 100644 index 0000000..7ac22d5 --- /dev/null +++ b/plugin/test/scrape.go @@ -0,0 +1,262 @@ +// Adapted by Miek Gieben for CoreDNS testing. +// +// License from prom2json +// Copyright 2014 Prometheus Team +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package test will scrape a target and you can inspect the variables. +// Basic usage: +// +// result := Scrape("http://localhost:9153/metrics") +// v := MetricValue("coredns_cache_capacity", result) +package test + +import ( + "fmt" + "io" + "mime" + "net/http" + "strconv" + + "github.com/matttproud/golang_protobuf_extensions/pbutil" + dto "github.com/prometheus/client_model/go" + "github.com/prometheus/common/expfmt" +) + +type ( + // MetricFamily holds a prometheus metric. 
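+	// Metrics holds metric, summary, or histogram values, depending on the family's Type.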
+ MetricFamily struct { + Name string `json:"name"` + Help string `json:"help"` + Type string `json:"type"` + Metrics []interface{} `json:"metrics,omitempty"` // Either metric or summary. + } + + // metric is for all "single value" metrics. + metric struct { + Labels map[string]string `json:"labels,omitempty"` + Value string `json:"value"` + } + + summary struct { + Labels map[string]string `json:"labels,omitempty"` + Quantiles map[string]string `json:"quantiles,omitempty"` + Count string `json:"count"` + Sum string `json:"sum"` + } + + histogram struct { + Labels map[string]string `json:"labels,omitempty"` + Buckets map[string]string `json:"buckets,omitempty"` + Count string `json:"count"` + Sum string `json:"sum"` + } +) + +// Scrape returns the all the vars a []*metricFamily. +func Scrape(url string) []*MetricFamily { + mfChan := make(chan *dto.MetricFamily, 1024) + + go fetchMetricFamilies(url, mfChan) + + result := []*MetricFamily{} + for mf := range mfChan { + result = append(result, newMetricFamily(mf)) + } + return result +} + +// ScrapeMetricAsInt provides a sum of all metrics collected for the name and label provided. +// if the metric is not a numeric value, it will be counted a 0. +func ScrapeMetricAsInt(addr string, name string, label string, nometricvalue int) int { + valueToInt := func(m metric) int { + v := m.Value + r, err := strconv.Atoi(v) + if err != nil { + return 0 + } + return r + } + + met := Scrape(fmt.Sprintf("http://%s/metrics", addr)) + found := false + tot := 0 + for _, mf := range met { + if mf.Name == name { + // Sum all metrics available + for _, m := range mf.Metrics { + if label == "" { + tot += valueToInt(m.(metric)) + found = true + continue + } + for _, v := range m.(metric).Labels { + if v == label { + tot += valueToInt(m.(metric)) + found = true + } + } + } + } + } + + if !found { + return nometricvalue + } + return tot +} + +// MetricValue returns the value associated with name as a string as well as the labels. +// It only returns the first metrics of the slice. +func MetricValue(name string, mfs []*MetricFamily) (string, map[string]string) { + for _, mf := range mfs { + if mf.Name == name { + // Only works with Gauge and Counter... + return mf.Metrics[0].(metric).Value, mf.Metrics[0].(metric).Labels + } + } + return "", nil +} + +// MetricValueLabel returns the value for name *and* label *value*. +func MetricValueLabel(name, label string, mfs []*MetricFamily) (string, map[string]string) { + // bit hacky is this really handy...? 
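+	// Note: only label *values* are compared; label names are ignored and the first match wins.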
+ for _, mf := range mfs { + if mf.Name == name { + for _, m := range mf.Metrics { + for _, v := range m.(metric).Labels { + if v == label { + return m.(metric).Value, m.(metric).Labels + } + } + } + } + } + return "", nil +} + +func newMetricFamily(dtoMF *dto.MetricFamily) *MetricFamily { + mf := &MetricFamily{ + Name: dtoMF.GetName(), + Help: dtoMF.GetHelp(), + Type: dtoMF.GetType().String(), + Metrics: make([]interface{}, len(dtoMF.Metric)), + } + for i, m := range dtoMF.Metric { + if dtoMF.GetType() == dto.MetricType_SUMMARY { + mf.Metrics[i] = summary{ + Labels: makeLabels(m), + Quantiles: makeQuantiles(m), + Count: fmt.Sprint(m.GetSummary().GetSampleCount()), + Sum: fmt.Sprint(m.GetSummary().GetSampleSum()), + } + } else if dtoMF.GetType() == dto.MetricType_HISTOGRAM { + mf.Metrics[i] = histogram{ + Labels: makeLabels(m), + Buckets: makeBuckets(m), + Count: fmt.Sprint(m.GetHistogram().GetSampleCount()), + Sum: fmt.Sprint(m.GetSummary().GetSampleSum()), + } + } else { + mf.Metrics[i] = metric{ + Labels: makeLabels(m), + Value: fmt.Sprint(value(m)), + } + } + } + return mf +} + +func value(m *dto.Metric) float64 { + if m.Gauge != nil { + return m.GetGauge().GetValue() + } + if m.Counter != nil { + return m.GetCounter().GetValue() + } + if m.Untyped != nil { + return m.GetUntyped().GetValue() + } + return 0. +} + +func makeLabels(m *dto.Metric) map[string]string { + result := map[string]string{} + for _, lp := range m.Label { + result[lp.GetName()] = lp.GetValue() + } + return result +} + +func makeQuantiles(m *dto.Metric) map[string]string { + result := map[string]string{} + for _, q := range m.GetSummary().Quantile { + result[fmt.Sprint(q.GetQuantile())] = fmt.Sprint(q.GetValue()) + } + return result +} + +func makeBuckets(m *dto.Metric) map[string]string { + result := map[string]string{} + for _, b := range m.GetHistogram().Bucket { + result[fmt.Sprint(b.GetUpperBound())] = fmt.Sprint(b.GetCumulativeCount()) + } + return result +} + +func fetchMetricFamilies(url string, ch chan<- *dto.MetricFamily) { + defer close(ch) + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return + } + req.Header.Add("Accept", acceptHeader) + resp, err := http.DefaultClient.Do(req) + if err != nil { + return + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return + } + + mediatype, params, err := mime.ParseMediaType(resp.Header.Get("Content-Type")) + if err == nil && mediatype == "application/vnd.google.protobuf" && + params["encoding"] == "delimited" && + params["proto"] == "io.prometheus.client.MetricFamily" { + for { + mf := &dto.MetricFamily{} + if _, err = pbutil.ReadDelimited(resp.Body, mf); err != nil { + if err == io.EOF { + break + } + return + } + ch <- mf + } + } else { + // We could do further content-type checks here, but the + // fallback for now will anyway be the text format + // version 0.0.4, so just go for it and see if it works. 
+ var parser expfmt.TextParser + metricFamilies, err := parser.TextToMetricFamilies(resp.Body) + if err != nil { + return + } + for _, mf := range metricFamilies { + ch <- mf + } + } +} + +const acceptHeader = `application/vnd.google.protobuf;proto=io.prometheus.client.MetricFamily;encoding=delimited;q=0.7,text/plain;version=0.0.4;q=0.3` diff --git a/plugin/timeouts/README.md b/plugin/timeouts/README.md new file mode 100644 index 0000000..098c9cc --- /dev/null +++ b/plugin/timeouts/README.md @@ -0,0 +1,76 @@ +# timeouts + +## Name + +*timeouts* - allows you to configure the server read, write and idle timeouts for the TCP, TLS and DoH servers. + +## Description + +CoreDNS is configured with sensible timeouts for server connections by default. +However in some cases for example where CoreDNS is serving over a slow mobile +data connection the default timeouts are not optimal. + +Additionally some routers hold open connections when using DNS over TLS or DNS +over HTTPS. Allowing a longer idle timeout helps performance and reduces issues +with such routers. + +The *timeouts* "plugin" allows you to configure CoreDNS server read, write and +idle timeouts. + +## Syntax + +~~~ txt +timeouts { + read DURATION + write DURATION + idle DURATION +} +~~~ + +For any timeouts that are not provided, default values are used which may vary +depending on the server type. At least one timeout must be specified otherwise +the entire timeouts block should be omitted. + +## Examples + +Start a DNS-over-TLS server that picks up incoming DNS-over-TLS queries on port +5553 and uses the nameservers defined in `/etc/resolv.conf` to resolve the +query. This proxy path uses plain old DNS. A 10 second read timeout, 20 +second write timeout and a 60 second idle timeout have been configured. + +~~~ +tls://.:5553 { + tls cert.pem key.pem ca.pem + timeouts { + read 10s + write 20s + idle 60s + } + forward . /etc/resolv.conf +} +~~~ + +Start a DNS-over-HTTPS server that is similar to the previous example. Only the +read timeout has been configured for 1 minute. + +~~~ +https://. { + tls cert.pem key.pem ca.pem + timeouts { + read 1m + } + forward . /etc/resolv.conf +} +~~~ + +Start a standard TCP/UDP server on port 1053. A read and write timeout has been +configured. The timeouts are only applied to the TCP side of the server. +~~~ +.:1053 { + timeouts { + read 15s + write 30s + } + forward . 
/etc/resolv.conf +} +~~~ diff --git a/plugin/timeouts/timeouts.go b/plugin/timeouts/timeouts.go new file mode 100644 index 0000000..eea6a64 --- /dev/null +++ b/plugin/timeouts/timeouts.go @@ -0,0 +1,69 @@ +package timeouts + +import ( + "time" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/durations" +) + +func init() { plugin.Register("timeouts", setup) } + +func setup(c *caddy.Controller) error { + err := parseTimeouts(c) + if err != nil { + return plugin.Error("timeouts", err) + } + return nil +} + +func parseTimeouts(c *caddy.Controller) error { + config := dnsserver.GetConfig(c) + + for c.Next() { + args := c.RemainingArgs() + if len(args) > 0 { + return plugin.Error("timeouts", c.ArgErr()) + } + + b := 0 + for c.NextBlock() { + block := c.Val() + timeoutArgs := c.RemainingArgs() + if len(timeoutArgs) != 1 { + return c.ArgErr() + } + + timeout, err := durations.NewDurationFromArg(timeoutArgs[0]) + if err != nil { + return c.Err(err.Error()) + } + + if timeout < (1*time.Second) || timeout > (24*time.Hour) { + return c.Errf("timeout provided '%s' needs to be between 1 second and 24 hours", timeout) + } + + switch block { + case "read": + config.ReadTimeout = timeout + + case "write": + config.WriteTimeout = timeout + + case "idle": + config.IdleTimeout = timeout + + default: + return c.Errf("unknown option: '%s'", block) + } + b++ + } + + if b == 0 { + return plugin.Error("timeouts", c.Err("timeouts block with no timeouts specified")) + } + } + return nil +} diff --git a/plugin/timeouts/timeouts_test.go b/plugin/timeouts/timeouts_test.go new file mode 100644 index 0000000..c01d3a0 --- /dev/null +++ b/plugin/timeouts/timeouts_test.go @@ -0,0 +1,75 @@ +package timeouts + +import ( + "strings" + "testing" + + "github.com/coredns/caddy" +) + +func TestTimeouts(t *testing.T) { + tests := []struct { + input string + shouldErr bool + expectedRoot string // expected root, set to the controller. Empty for negative cases. + expectedErrContent string // substring from the expected error. Empty for positive cases. + }{ + // positive + {`timeouts { + read 30s + }`, false, "", ""}, + {`timeouts { + read 1m + write 2m + }`, false, "", ""}, + {` timeouts { + idle 1h + }`, false, "", ""}, + {`timeouts { + read 10 + write 20 + idle 60 + }`, false, "", ""}, + // negative + {`timeouts`, true, "", "block with no timeouts specified"}, + {`timeouts { + }`, true, "", "block with no timeouts specified"}, + {`timeouts { + read 10s + giraffe 30s + }`, true, "", "unknown option"}, + {`timeouts { + read 10s 20s + write 30s + }`, true, "", "Wrong argument"}, + {`timeouts { + write snake + }`, true, "", "failed to parse duration"}, + {`timeouts { + idle 0s + }`, true, "", "needs to be between"}, + {`timeouts { + read 48h + }`, true, "", "needs to be between"}, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + err := setup(c) + //cfg := dnsserver.GetConfig(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected error but found %s for input %s", i, err, test.input) + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: Expected no error but found one for input %s. 
Error was: %v", i, test.input, err) + } + + if !strings.Contains(err.Error(), test.expectedErrContent) { + t.Errorf("Test %d: Expected error to contain: %v, found error: %v, input: %s", i, test.expectedErrContent, err, test.input) + } + } + } +} diff --git a/plugin/tls/README.md b/plugin/tls/README.md new file mode 100644 index 0000000..9d945b8 --- /dev/null +++ b/plugin/tls/README.md @@ -0,0 +1,73 @@ +# tls + +## Name + +*tls* - allows you to configure the server certificates for the TLS, gRPC, DoH servers. + +## Description + +CoreDNS supports queries that are encrypted using TLS (DNS over Transport Layer Security, RFC 7858) +or are using gRPC (https://grpc.io/, not an IETF standard). Normally DNS traffic isn't encrypted at +all (DNSSEC only signs resource records). + +The *tls* "plugin" allows you to configure the cryptographic keys that are needed for both +DNS-over-TLS and DNS-over-gRPC. If the *tls* plugin is omitted, then no encryption takes place. + +The gRPC protobuffer is defined in `pb/dns.proto`. It defines the proto as a simple wrapper for the +wire data of a DNS message. + +## Syntax + +~~~ txt +tls CERT KEY [CA] +~~~ + +Parameter CA is optional. If not set, system CAs can be used to verify the client certificate + +~~~ txt +tls CERT KEY [CA] { + client_auth nocert|request|require|verify_if_given|require_and_verify +} +~~~ + +If client\_auth option is specified, it controls the client authentication policy. +The option value corresponds to the [ClientAuthType values of the Go tls package](https://golang.org/pkg/crypto/tls/#ClientAuthType): NoClientCert, RequestClientCert, RequireAnyClientCert, VerifyClientCertIfGiven, and RequireAndVerifyClientCert, respectively. +The default is "nocert". Note that it makes no sense to specify parameter CA unless this option is +set to verify\_if\_given or require\_and\_verify. + +## Examples + +Start a DNS-over-TLS server that picks up incoming DNS-over-TLS queries on port 5553 and uses the +nameservers defined in `/etc/resolv.conf` to resolve the query. This proxy path uses plain old DNS. + +~~~ +tls://.:5553 { + tls cert.pem key.pem ca.pem + forward . /etc/resolv.conf +} +~~~ + +Start a DNS-over-gRPC server that is similar to the previous example, but using DNS-over-gRPC for +incoming queries. + +~~~ +grpc://. { + tls cert.pem key.pem ca.pem + forward . /etc/resolv.conf +} +~~~ + +Start a DoH server on port 443 that is similar to the previous example, but using DoH for incoming queries. +~~~ +https://. { + tls cert.pem key.pem ca.pem + forward . /etc/resolv.conf +} +~~~ + +Only Knot DNS' `kdig` supports DNS-over-TLS queries, no command line client supports gRPC making +debugging these transports harder than it should be. + +## See Also + +RFC 7858 and https://grpc.io. 
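+
+A server block that additionally requires and verifies client certificates against the supplied CA
+could look like this (certificate paths are illustrative):
+
+~~~
+tls://.:5553 {
+    tls cert.pem key.pem ca.pem {
+        client_auth require_and_verify
+    }
+    forward . /etc/resolv.conf
+}
+~~~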
diff --git a/plugin/tls/log_test.go b/plugin/tls/log_test.go new file mode 100644 index 0000000..017affd --- /dev/null +++ b/plugin/tls/log_test.go @@ -0,0 +1,5 @@ +package tls + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/tls/test_ca.pem b/plugin/tls/test_ca.pem new file mode 100644 index 0000000..cfcd5cc --- /dev/null +++ b/plugin/tls/test_ca.pem @@ -0,0 +1,20 @@ +-----BEGIN CERTIFICATE----- +MIIDPzCCAiegAwIBAgIJAPjCWTu1wGapMA0GCSqGSIb3DQEBCwUAMDUxCzAJBgNV +BAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMREwDwYDVQQKDAhJbmZvYmxveDAg +Fw0xOTA1MTEwMDI3NDRaGA8yMTE5MDQxNzAwMjc0NFowNTELMAkGA1UEBhMCVVMx +EzARBgNVBAgMCkNhbGlmb3JuaWExETAPBgNVBAoMCEluZm9ibG94MIIBIjANBgkq +hkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArAYiw1UjlYj+nITRUlj5hA7j8U2qWcyN +YcDfqQnt173Z8yR7NJokqt3Bd3PlrBZS2XtYSNohxRr4qeJu/g7UBre/fSEU/ZOM +Gl7NjBGKQEymJ0d8rBg52iiGNwU+ERI9pcQRA6DCEjVbOmjDiUd5yzuVotG/Sxep +GUJ2puJ0p0gWCMEL9sdqY6HHd/hdj6B6+u2xD9NUCkX9pLC7CPFJHnP0vLO4WIWL +z5C7yzpeLO9r7Nfnu+2HcRLmuFZVPNxkMq7UymqR1w5ZYJQ5p9E7pyxDVXxHnTqQ +yLaAS2/9umrOwVnD1NaN3OdAhDedXbH0cF08GcIQD9rnlkLMW4CKtwIDAQABo1Aw +TjAdBgNVHQ4EFgQUHcxJPBmHF0nSv+FJJI/kwrSThf8wHwYDVR0jBBgwFoAUHcxJ +PBmHF0nSv+FJJI/kwrSThf8wDAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAAOC +AQEAByItgyhlXDv2wnnMVXHHlUCbsKCOtBJZ8EumvKjeOx5G4gqJpQIQPNeBv1Od +QT7d15HfT7RQqHSL0uAoGuNuyGjZGWWbLMkVt8T0tXY2v9Dd8eWC/lFaaA0vkqTG +GpADSmH+SoFAdPPcYN/sXmEHvZcIQ0wUxuF48ZMwOh7ZOcrZggxlA9+BKHU4fO03 +o7krzpQZQmEDXNN8bt1R0DIhVADw/G2oJAzK0LGhh4eu6hj6k/cAWS6ujRBGqN0Z +fURCrMEyjzbNybhkU1KqSr7eSJOWkl4UJ5Ns/dt9/yw2BBrKH3Mijch7UA8mlbEE +29M28u2W7GMXLSSwmtCqDBRNhg== +-----END CERTIFICATE----- diff --git a/plugin/tls/test_cert.pem b/plugin/tls/test_cert.pem new file mode 100644 index 0000000..8cc47eb --- /dev/null +++ b/plugin/tls/test_cert.pem @@ -0,0 +1,20 @@ +-----BEGIN CERTIFICATE----- +MIIDPzCCAiegAwIBAgIJAPezzzshGRiTMA0GCSqGSIb3DQEBCwUAMDUxCzAJBgNV +BAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMREwDwYDVQQKDAhJbmZvYmxveDAg +Fw0xOTA1MTEwMDI2MjNaGA8yMTE5MDQxNzAwMjYyM1owNTELMAkGA1UEBhMCVVMx +EzARBgNVBAgMCkNhbGlmb3JuaWExETAPBgNVBAoMCEluZm9ibG94MIIBIjANBgkq +hkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArAYiw1UjlYj+nITRUlj5hA7j8U2qWcyN +YcDfqQnt173Z8yR7NJokqt3Bd3PlrBZS2XtYSNohxRr4qeJu/g7UBre/fSEU/ZOM +Gl7NjBGKQEymJ0d8rBg52iiGNwU+ERI9pcQRA6DCEjVbOmjDiUd5yzuVotG/Sxep +GUJ2puJ0p0gWCMEL9sdqY6HHd/hdj6B6+u2xD9NUCkX9pLC7CPFJHnP0vLO4WIWL +z5C7yzpeLO9r7Nfnu+2HcRLmuFZVPNxkMq7UymqR1w5ZYJQ5p9E7pyxDVXxHnTqQ +yLaAS2/9umrOwVnD1NaN3OdAhDedXbH0cF08GcIQD9rnlkLMW4CKtwIDAQABo1Aw +TjAdBgNVHQ4EFgQUHcxJPBmHF0nSv+FJJI/kwrSThf8wHwYDVR0jBBgwFoAUHcxJ +PBmHF0nSv+FJJI/kwrSThf8wDAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAAOC +AQEAQyN9nLImdtufuSjXcrCJ3alt/vffHJIzlPgDsNw8+tjI7aRX7CzuurOOEQUC +fJ9A6O+dat5k5yqVb9hDcD42HXtOjRQDYpQ6dOGirLFThIFSMC/7RiqHk0YtxojM +ZNBbgXo4o1d+P9b25oc/+pRDzlOvqNL7IzW/LDHnJ4j6tBNguujCB5QFUF5dOa1z +UR5rupMvv2KpEgRcfW/d3kwcAxH9nI0SHKJenhtweyajUgInK88TC+aT4909c2XA +EADYyWxj1DMz3/sMpvGegHsfTPegNoDgz2yEKdu53dr4BUpF6E+eoCX9Hv78SWH3 +/rAlkbffzCL5d+I8y0jzEpLEqA== +-----END CERTIFICATE----- diff --git a/plugin/tls/test_key.pem b/plugin/tls/test_key.pem new file mode 100644 index 0000000..2ca2b21 --- /dev/null +++ b/plugin/tls/test_key.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCsBiLDVSOViP6c +hNFSWPmEDuPxTapZzI1hwN+pCe3XvdnzJHs0miSq3cF3c+WsFlLZe1hI2iHFGvip +4m7+DtQGt799IRT9k4waXs2MEYpATKYnR3ysGDnaKIY3BT4REj2lxBEDoMISNVs6 +aMOJR3nLO5Wi0b9LF6kZQnam4nSnSBYIwQv2x2pjocd3+F2PoHr67bEP01QKRf2k +sLsI8Ukec/S8s7hYhYvPkLvLOl4s72vs1+e77YdxEua4VlU83GQyrtTKapHXDllg 
+lDmn0TunLENVfEedOpDItoBLb/26as7BWcPU1o3c50CEN51dsfRwXTwZwhAP2ueW +QsxbgIq3AgMBAAECggEAF3FCnYHltoQTxnqnF+S+JAvvbjvaQiCJB9BD6oJK4kKi +B+tpytJSuuI7ci7eFqR4J+ESN+NaBMVXK7eKzp5wsHWr575xYNkRl6phsnvVbkvD +vMiWKdGnWJ57I9ZYDfWBZyyf8PGgYODajMwoEXYnF9YH30dcHTydM68GAloL8Zu9 +CtGCmlu4TER0BvG+rK2OD5lt8ORK56eMwzTTqMy0hCkP5VEq8j9RmekEzrgtWKm8 +OI3i8VnpOA0RCVhJ0q5a5jt/xbKRjFNsUNmy9HBRYg7Iw3SCEHmDtz1R9A9rvaJC +WXqwKbGZPY8W69h8BhKcJ5RrKt2PZyJxw+LB610XSQKBgQDR/LIGXdJR/90epiGC +p68W9Vc3eWxJlAtLDQCSULphLi6j7D+jesmhD3z2woBPjxkd4TaZa2t94Q1MzSeC +ON/Aux1huto9ddxvijUQJN3Ep4zPkHdNzHfRwIZsgGH8u77VY/5I4V7IgxKjWlJ6 +Ii8ez8xpWj1rnQ0azSaYIcVl7QKBgQDRt+J+iRjKxHWuXoBFfv8oMfl+iYaMdJxu +PELWb3RLsZ92hobSAmNR/gC3T7p8NFJlQVCoxZr8zt/Rvqh4aK3aSOuKeUvYAjs1 +/YbPcdSn6uTTIOi6CcHaJ8ZUXNvY5FuoT0+Q9Eb8fw5NGzxsgsfhScELLgbFKb5E +Tkw43ZqeswKBgQCxXBgZnIEaVVw0mOlQ68TNRWfnKR23f92SBGdpLdpeXp1yQwb1 +U66d5PENkvbBPAJg5GozZzGhXsbXCajHKraCmQiWFTZkFvqbE0cCXcEaatJaNpEu +GvdRKKXhWwZoa0MiBZUvhXuDLII/iviCxAC8q5LhoSCjlkENVB22/T83eQKBgQC4 +c3wRALG+fWZns5QsC5ONnc6rXXfqhxGi3vuGMMbfYF05WP6xLQp/7eBhWg1R+o7R +oc24cvxrB+TRTFhOdvsZtvL7es2bMfMz/EUapSp9edpCW3p1Temi30LPplByhf6b +nQ4FFuRsZa+FX8QYSDpWypCwLY4k0R8YYqklhrrcgwKBgFiM/GnRc230nj0GGWf1 ++Ve2M/TQCgS6ufr2F0vU7QkEWfeiN9iunhmhsggqWxOEOU77FhCkQRtztm93hG0K +eKoHNh/1HvHGBWsR0TaMDw3n8t7Yg5NmQb617nBELZbxxpd358muLiHDoix86W9Q +xM6hB159G1gOEJsi8exm5AlZ +-----END PRIVATE KEY----- diff --git a/plugin/tls/tls.go b/plugin/tls/tls.go new file mode 100644 index 0000000..ff60b67 --- /dev/null +++ b/plugin/tls/tls.go @@ -0,0 +1,77 @@ +package tls + +import ( + ctls "crypto/tls" + "path/filepath" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/tls" +) + +func init() { plugin.Register("tls", setup) } + +func setup(c *caddy.Controller) error { + err := parseTLS(c) + if err != nil { + return plugin.Error("tls", err) + } + return nil +} + +func parseTLS(c *caddy.Controller) error { + config := dnsserver.GetConfig(c) + + if config.TLSConfig != nil { + return plugin.Error("tls", c.Errf("TLS already configured for this server instance")) + } + + for c.Next() { + args := c.RemainingArgs() + if len(args) < 2 || len(args) > 3 { + return plugin.Error("tls", c.ArgErr()) + } + clientAuth := ctls.NoClientCert + for c.NextBlock() { + switch c.Val() { + case "client_auth": + authTypeArgs := c.RemainingArgs() + if len(authTypeArgs) != 1 { + return c.ArgErr() + } + switch authTypeArgs[0] { + case "nocert": + clientAuth = ctls.NoClientCert + case "request": + clientAuth = ctls.RequestClientCert + case "require": + clientAuth = ctls.RequireAnyClientCert + case "verify_if_given": + clientAuth = ctls.VerifyClientCertIfGiven + case "require_and_verify": + clientAuth = ctls.RequireAndVerifyClientCert + default: + return c.Errf("unknown authentication type '%s'", authTypeArgs[0]) + } + default: + return c.Errf("unknown option '%s'", c.Val()) + } + } + for i := range args { + if !filepath.IsAbs(args[i]) && config.Root != "" { + args[i] = filepath.Join(config.Root, args[i]) + } + } + tls, err := tls.NewTLSConfigFromArgs(args...) + if err != nil { + return err + } + tls.ClientAuth = clientAuth + // NewTLSConfigFromArgs only sets RootCAs, so we need to let ClientCAs refer to it. 
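+		// This lets the optional CA argument also verify client certificates when client_auth is
+		// set to verify_if_given or require_and_verify.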
+ tls.ClientCAs = tls.RootCAs + + config.TLSConfig = tls + } + return nil +} diff --git a/plugin/tls/tls_test.go b/plugin/tls/tls_test.go new file mode 100644 index 0000000..7deb837 --- /dev/null +++ b/plugin/tls/tls_test.go @@ -0,0 +1,87 @@ +package tls + +import ( + "crypto/tls" + "strings" + "testing" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" +) + +func TestTLS(t *testing.T) { + tests := []struct { + input string + shouldErr bool + expectedRoot string // expected root, set to the controller. Empty for negative cases. + expectedErrContent string // substring from the expected error. Empty for positive cases. + }{ + // positive + {"tls test_cert.pem test_key.pem test_ca.pem", false, "", ""}, + {"tls test_cert.pem test_key.pem test_ca.pem {\nclient_auth nocert\n}", false, "", ""}, + {"tls test_cert.pem test_key.pem test_ca.pem {\nclient_auth request\n}", false, "", ""}, + {"tls test_cert.pem test_key.pem test_ca.pem {\nclient_auth require\n}", false, "", ""}, + {"tls test_cert.pem test_key.pem test_ca.pem {\nclient_auth verify_if_given\n}", false, "", ""}, + {"tls test_cert.pem test_key.pem test_ca.pem {\nclient_auth require_and_verify\n}", false, "", ""}, + // negative + {"tls test_cert.pem test_key.pem test_ca.pem {\nunknown\n}", true, "", "unknown option"}, + // client_auth takes exactly one parameter, which must be one of known keywords. + {"tls test_cert.pem test_key.pem test_ca.pem {\nclient_auth\n}", true, "", "Wrong argument"}, + {"tls test_cert.pem test_key.pem test_ca.pem {\nclient_auth none bogus\n}", true, "", "Wrong argument"}, + {"tls test_cert.pem test_key.pem test_ca.pem {\nclient_auth bogus\n}", true, "", "unknown authentication type"}, + } + + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + err := setup(c) + //cfg := dnsserver.GetConfig(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected error but found %s for input %s", i, err, test.input) + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: Expected no error but found one for input %s. Error was: %v", i, test.input, err) + } + + if !strings.Contains(err.Error(), test.expectedErrContent) { + t.Errorf("Test %d: Expected error to contain: %v, found error: %v, input: %s", i, test.expectedErrContent, err, test.input) + } + } + } +} + +func TestTLSClientAuthentication(t *testing.T) { + // Invalid configurations are tested in the general test case. In this test we only look into specific details of valid client_auth options. + tests := []struct { + option string // tls plugin option(s) + expectedType tls.ClientAuthType // expected authentication type. + }{ + // By default, or if 'nocert' is specified, no cert should be requested. + // Other cases should be a straightforward mapping from the keyword to the type value. + {"", tls.NoClientCert}, + {"{\nclient_auth nocert\n}", tls.NoClientCert}, + {"{\nclient_auth request\n}", tls.RequestClientCert}, + {"{\nclient_auth require\n}", tls.RequireAnyClientCert}, + {"{\nclient_auth verify_if_given\n}", tls.VerifyClientCertIfGiven}, + {"{\nclient_auth require_and_verify\n}", tls.RequireAndVerifyClientCert}, + } + + for i, test := range tests { + input := "tls test_cert.pem test_key.pem test_ca.pem " + test.option + c := caddy.NewTestController("dns", input) + err := setup(c) + if err != nil { + t.Errorf("Test %d: TLS config is unexpectedly rejected: %v", i, err) + continue // there's no point in the rest of the tests. 
+ } + cfg := dnsserver.GetConfig(c) + if cfg.TLSConfig.ClientCAs == nil { + t.Errorf("Test %d: Client CA is not configured", i) + } + if cfg.TLSConfig.ClientAuth != test.expectedType { + t.Errorf("Test %d: Unexpected client auth type: %d", i, cfg.TLSConfig.ClientAuth) + } + } +} diff --git a/plugin/trace/README.md b/plugin/trace/README.md new file mode 100644 index 0000000..eac8a7b --- /dev/null +++ b/plugin/trace/README.md @@ -0,0 +1,113 @@ +# trace + +## Name + +*trace* - enables OpenTracing-based tracing of DNS requests as they go through the plugin chain. + +## Description + +With *trace* you enable OpenTracing of how a request flows through CoreDNS. Enable the *debug* +plugin to get logs from the trace plugin. + +## Syntax + +The simplest form is just: + +~~~ +trace [ENDPOINT-TYPE] [ENDPOINT] +~~~ + +* **ENDPOINT-TYPE** is the type of tracing destination. Currently only `zipkin` and `datadog` are supported. + Defaults to `zipkin`. +* **ENDPOINT** is the tracing destination, and defaults to `localhost:9411`. For Zipkin, if + **ENDPOINT** does not begin with `http`, then it will be transformed to `http://ENDPOINT/api/v1/spans`. + +With this form, all queries will be traced. + +Additional features can be enabled with this syntax: + +~~~ +trace [ENDPOINT-TYPE] [ENDPOINT] { + every AMOUNT + service NAME + client_server + datadog_analytics_rate RATE + zipkin_max_backlog_size SIZE + zipkin_max_batch_size SIZE + zipkin_max_batch_interval DURATION +} +~~~ + +* `every` **AMOUNT** will only trace one query of each AMOUNT queries. For example, to trace 1 in every + 100 queries, use AMOUNT of 100. The default is 1. +* `service` **NAME** allows you to specify the service name reported to the tracing server. + Default is `coredns`. +* `client_server` will enable the `ClientServerSameSpan` OpenTracing feature. +* `datadog_analytics_rate` **RATE** will enable [trace analytics](https://docs.datadoghq.com/tracing/app_analytics) on the traces sent + from *0* to *1*, *1* being every trace sent will be analyzed. This is a datadog only feature + (**ENDPOINT-TYPE** needs to be `datadog`) +* `zipkin_max_backlog_size` configures the maximum backlog size for Zipkin HTTP reporter. When batch size reaches this threshold, + spans from the beginning of the batch will be disposed. Default is 1000 backlog size. +* `zipkin_max_batch_size` configures the maximum batch size for Zipkin HTTP reporter, after which a collect will be triggered. The default batch size is 100 traces. +* `zipkin_max_batch_interval` configures the maximum duration we will buffer traces before emitting them to the collector using Zipkin HTTP reporter. + The default batch interval is 1 second. + +## Zipkin + +You can run Zipkin on a Docker host like this: + +``` +docker run -d -p 9411:9411 openzipkin/zipkin +``` + +Note the zipkin provider does not support the v1 API since coredns 1.7.1. + +## Examples + +Use an alternative Zipkin address: + +~~~ +trace tracinghost:9253 +~~~ + +or + +~~~ corefile +. 
{ + trace zipkin tracinghost:9253 +} +~~~ + +If for some reason you are using an API reverse proxy or something and need to remap +the standard Zipkin URL you can do something like: + +~~~ +trace http://tracinghost:9411/zipkin/api/v1/spans +~~~ + +Using DataDog: + +~~~ +trace datadog localhost:8126 +~~~ + +Trace one query every 10000 queries, rename the service, and enable same span: + +~~~ +trace tracinghost:9411 { + every 10000 + service dnsproxy + client_server +} +~~~ + +## Metadata + +The trace plugin will publish the following metadata, if the *metadata* +plugin is also enabled: + +* `trace/traceid`: identifier of (zipkin/datadog) trace of processed request + +## See Also + +See the *debug* plugin for more information about debug logging. diff --git a/plugin/trace/log_test.go b/plugin/trace/log_test.go new file mode 100644 index 0000000..a0fe761 --- /dev/null +++ b/plugin/trace/log_test.go @@ -0,0 +1,5 @@ +package trace + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/trace/logger.go b/plugin/trace/logger.go new file mode 100644 index 0000000..6499387 --- /dev/null +++ b/plugin/trace/logger.go @@ -0,0 +1,20 @@ +package trace + +import ( + clog "github.com/coredns/coredns/plugin/pkg/log" +) + +// loggerAdapter is a simple adapter around plugin logger made to implement io.Writer and ddtrace.Logger interface +// in order to log errors from span reporters as warnings +type loggerAdapter struct { + clog.P +} + +func (l *loggerAdapter) Write(p []byte) (n int, err error) { + l.P.Warning(string(p)) + return len(p), nil +} + +func (l *loggerAdapter) Log(msg string) { + l.P.Warning(msg) +} diff --git a/plugin/trace/setup.go b/plugin/trace/setup.go new file mode 100644 index 0000000..8672dcc --- /dev/null +++ b/plugin/trace/setup.go @@ -0,0 +1,163 @@ +package trace + +import ( + "fmt" + "strconv" + "strings" + "time" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" +) + +func init() { plugin.Register("trace", setup) } + +func setup(c *caddy.Controller) error { + t, err := traceParse(c) + if err != nil { + return plugin.Error("trace", err) + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + t.Next = next + return t + }) + + c.OnStartup(t.OnStartup) + + return nil +} + +func traceParse(c *caddy.Controller) (*trace, error) { + var ( + tr = &trace{every: 1, serviceName: defServiceName} + err error + ) + + cfg := dnsserver.GetConfig(c) + if cfg.ListenHosts[0] != "" { + tr.serviceEndpoint = cfg.ListenHosts[0] + ":" + cfg.Port + } + + for c.Next() { // trace + var err error + args := c.RemainingArgs() + switch len(args) { + case 0: + tr.EndpointType, tr.Endpoint, err = normalizeEndpoint(defEpType, "") + case 1: + tr.EndpointType, tr.Endpoint, err = normalizeEndpoint(defEpType, args[0]) + case 2: + epType := strings.ToLower(args[0]) + tr.EndpointType, tr.Endpoint, err = normalizeEndpoint(epType, args[1]) + default: + err = c.ArgErr() + } + if err != nil { + return tr, err + } + for c.NextBlock() { + switch c.Val() { + case "every": + args := c.RemainingArgs() + if len(args) != 1 { + return nil, c.ArgErr() + } + tr.every, err = strconv.ParseUint(args[0], 10, 64) + if err != nil { + return nil, err + } + case "service": + args := c.RemainingArgs() + if len(args) != 1 { + return nil, c.ArgErr() + } + tr.serviceName = args[0] + case "client_server": + args := c.RemainingArgs() + if len(args) > 1 { + return nil, c.ArgErr() + } + tr.clientServer = 
true + if len(args) == 1 { + tr.clientServer, err = strconv.ParseBool(args[0]) + } + if err != nil { + return nil, err + } + case "datadog_analytics_rate": + args := c.RemainingArgs() + if len(args) > 1 { + return nil, c.ArgErr() + } + tr.datadogAnalyticsRate = 0 + if len(args) == 1 { + tr.datadogAnalyticsRate, err = strconv.ParseFloat(args[0], 64) + } + if err != nil { + return nil, err + } + if tr.datadogAnalyticsRate > 1 || tr.datadogAnalyticsRate < 0 { + return nil, fmt.Errorf("datadog analytics rate must be between 0 and 1, '%f' is not supported", tr.datadogAnalyticsRate) + } + case "zipkin_max_backlog_size": + args := c.RemainingArgs() + if len(args) != 1 { + return nil, c.ArgErr() + } + tr.zipkinMaxBacklogSize, err = strconv.Atoi(args[0]) + if err != nil { + return nil, err + } + case "zipkin_max_batch_size": + args := c.RemainingArgs() + if len(args) != 1 { + return nil, c.ArgErr() + } + tr.zipkinMaxBatchSize, err = strconv.Atoi(args[0]) + if err != nil { + return nil, err + } + case "zipkin_max_batch_interval": + args := c.RemainingArgs() + if len(args) != 1 { + return nil, c.ArgErr() + } + tr.zipkinMaxBatchInterval, err = time.ParseDuration(args[0]) + if err != nil { + return nil, err + } + } + } + } + return tr, err +} + +func normalizeEndpoint(epType, ep string) (string, string, error) { + if _, ok := supportedProviders[epType]; !ok { + return "", "", fmt.Errorf("tracing endpoint type '%s' is not supported", epType) + } + + if ep == "" { + ep = supportedProviders[epType] + } + + if epType == "zipkin" { + if !strings.Contains(ep, "http") { + ep = "http://" + ep + "/api/v2/spans" + } + } + + return epType, ep, nil +} + +var supportedProviders = map[string]string{ + "zipkin": "localhost:9411", + "datadog": "localhost:8126", +} + +const ( + defEpType = "zipkin" + defServiceName = "coredns" +) diff --git a/plugin/trace/setup_test.go b/plugin/trace/setup_test.go new file mode 100644 index 0000000..72de4ab --- /dev/null +++ b/plugin/trace/setup_test.go @@ -0,0 +1,88 @@ +package trace + +import ( + "testing" + "time" + + "github.com/coredns/caddy" +) + +func TestTraceParse(t *testing.T) { + tests := []struct { + input string + shouldErr bool + endpoint string + every uint64 + serviceName string + clientServer bool + zipkinMaxBacklogSize int + zipkinMaxBatchSize int + zipkinMaxBatchInterval time.Duration + }{ + // oks + {`trace`, false, "http://localhost:9411/api/v2/spans", 1, `coredns`, false, 0, 0, 0}, + {`trace localhost:1234`, false, "http://localhost:1234/api/v2/spans", 1, `coredns`, false, 0, 0, 0}, + {`trace http://localhost:1234/somewhere/else`, false, "http://localhost:1234/somewhere/else", 1, `coredns`, false, 0, 0, 0}, + {`trace zipkin localhost:1234`, false, "http://localhost:1234/api/v2/spans", 1, `coredns`, false, 0, 0, 0}, + {`trace datadog localhost`, false, "localhost", 1, `coredns`, false, 0, 0, 0}, + {`trace datadog http://localhost:8127`, false, "http://localhost:8127", 1, `coredns`, false, 0, 0, 0}, + {"trace datadog localhost {\n datadog_analytics_rate 0.1\n}", false, "localhost", 1, `coredns`, false, 0, 0, 0}, + {"trace {\n every 100\n}", false, "http://localhost:9411/api/v2/spans", 100, `coredns`, false, 0, 0, 0}, + {"trace {\n every 100\n service foobar\nclient_server\n}", false, "http://localhost:9411/api/v2/spans", 100, `foobar`, true, 0, 0, 0}, + {"trace {\n every 2\n client_server true\n}", false, "http://localhost:9411/api/v2/spans", 2, `coredns`, true, 0, 0, 0}, + {"trace {\n client_server false\n}", false, "http://localhost:9411/api/v2/spans", 1, 
`coredns`, false, 0, 0, 0}, + {"trace {\n zipkin_max_backlog_size 100\n zipkin_max_batch_size 200\n zipkin_max_batch_interval 10s\n}", false, + "http://localhost:9411/api/v2/spans", 1, `coredns`, false, 100, 200, 10 * time.Second}, + + // fails + {`trace footype localhost:4321`, true, "", 1, "", false, 0, 0, 0}, + {"trace {\n every 2\n client_server junk\n}", true, "", 1, "", false, 0, 0, 0}, + {"trace datadog localhost {\n datadog_analytics_rate 2\n}", true, "", 1, "", false, 0, 0, 0}, + {"trace {\n zipkin_max_backlog_size wrong\n}", true, "", 1, `coredns`, false, 0, 0, 0}, + {"trace {\n zipkin_max_batch_size wrong\n}", true, "", 1, `coredns`, false, 0, 0, 0}, + {"trace {\n zipkin_max_batch_interval wrong\n}", true, "", 1, `coredns`, false, 0, 0, 0}, + {"trace {\n zipkin_max_backlog_size\n}", true, "", 1, `coredns`, false, 0, 0, 0}, + {"trace {\n zipkin_max_batch_size\n}", true, "", 1, `coredns`, false, 0, 0, 0}, + {"trace {\n zipkin_max_batch_interval\n}", true, "", 1, `coredns`, false, 0, 0, 0}, + } + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + m, err := traceParse(c) + if test.shouldErr && err == nil { + t.Errorf("Test %v: Expected error but found nil", i) + continue + } else if !test.shouldErr && err != nil { + t.Errorf("Test %v: Expected no error but found error: %v", i, err) + continue + } + + if test.shouldErr { + continue + } + + if "" != m.serviceEndpoint { + t.Errorf("Test %v: Expected serviceEndpoint to be '' but found: %s", i, m.serviceEndpoint) + } + if test.endpoint != m.Endpoint { + t.Errorf("Test %v: Expected endpoint %s but found: %s", i, test.endpoint, m.Endpoint) + } + if test.every != m.every { + t.Errorf("Test %v: Expected every %d but found: %d", i, test.every, m.every) + } + if test.serviceName != m.serviceName { + t.Errorf("Test %v: Expected service name %s but found: %s", i, test.serviceName, m.serviceName) + } + if test.clientServer != m.clientServer { + t.Errorf("Test %v: Expected client_server %t but found: %t", i, test.clientServer, m.clientServer) + } + if test.zipkinMaxBacklogSize != m.zipkinMaxBacklogSize { + t.Errorf("Test %v: Expected zipkin_max_backlog_size %d but found: %d", i, test.zipkinMaxBacklogSize, m.zipkinMaxBacklogSize) + } + if test.zipkinMaxBatchSize != m.zipkinMaxBatchSize { + t.Errorf("Test %v: Expected zipkin_max_batch_size %d but found: %d", i, test.zipkinMaxBatchSize, m.zipkinMaxBatchSize) + } + if test.zipkinMaxBatchInterval != m.zipkinMaxBatchInterval { + t.Errorf("Test %v: Expected zipkin_max_batch_interval %v but found: %v", i, test.zipkinMaxBatchInterval, m.zipkinMaxBatchInterval) + } + } +} diff --git a/plugin/trace/trace.go b/plugin/trace/trace.go new file mode 100644 index 0000000..f740967 --- /dev/null +++ b/plugin/trace/trace.go @@ -0,0 +1,204 @@ +// Package trace implements OpenTracing-based tracing +package trace + +import ( + "context" + "fmt" + stdlog "log" + "net/http" + "sync" + "sync/atomic" + "time" + + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/metadata" + "github.com/coredns/coredns/plugin/pkg/dnstest" + clog "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/plugin/pkg/rcode" + _ "github.com/coredns/coredns/plugin/pkg/trace" // Plugin the trace package. 
+ "github.com/coredns/coredns/request" + + "github.com/miekg/dns" + ot "github.com/opentracing/opentracing-go" + otext "github.com/opentracing/opentracing-go/ext" + otlog "github.com/opentracing/opentracing-go/log" + zipkinot "github.com/openzipkin-contrib/zipkin-go-opentracing" + "github.com/openzipkin/zipkin-go" + zipkinhttp "github.com/openzipkin/zipkin-go/reporter/http" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/opentracer" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" +) + +const ( + defaultTopLevelSpanName = "servedns" + metaTraceIdKey = "trace/traceid" +) + +var log = clog.NewWithPlugin("trace") + +type traceTags struct { + Name string + Type string + Rcode string + Proto string + Remote string +} + +var tagByProvider = map[string]traceTags{ + "default": { + Name: "coredns.io/name", + Type: "coredns.io/type", + Rcode: "coredns.io/rcode", + Proto: "coredns.io/proto", + Remote: "coredns.io/remote", + }, + "datadog": { + Name: "coredns.io@name", + Type: "coredns.io@type", + Rcode: "coredns.io@rcode", + Proto: "coredns.io@proto", + Remote: "coredns.io@remote", + }, +} + +type trace struct { + count uint64 // as per Go spec, needs to be first element in a struct + + Next plugin.Handler + Endpoint string + EndpointType string + tracer ot.Tracer + serviceEndpoint string + serviceName string + clientServer bool + every uint64 + datadogAnalyticsRate float64 + zipkinMaxBacklogSize int + zipkinMaxBatchSize int + zipkinMaxBatchInterval time.Duration + Once sync.Once + tagSet traceTags +} + +func (t *trace) Tracer() ot.Tracer { + return t.tracer +} + +// OnStartup sets up the tracer +func (t *trace) OnStartup() error { + var err error + t.Once.Do(func() { + switch t.EndpointType { + case "zipkin": + err = t.setupZipkin() + case "datadog": + tracer := opentracer.New( + tracer.WithAgentAddr(t.Endpoint), + tracer.WithDebugMode(clog.D.Value()), + tracer.WithGlobalTag(ext.SpanTypeDNS, true), + tracer.WithServiceName(t.serviceName), + tracer.WithAnalyticsRate(t.datadogAnalyticsRate), + tracer.WithLogger(&loggerAdapter{log}), + ) + t.tracer = tracer + t.tagSet = tagByProvider["datadog"] + default: + err = fmt.Errorf("unknown endpoint type: %s", t.EndpointType) + } + }) + return err +} + +func (t *trace) setupZipkin() error { + var opts []zipkinhttp.ReporterOption + opts = append(opts, zipkinhttp.Logger(stdlog.New(&loggerAdapter{log}, "", 0))) + if t.zipkinMaxBacklogSize != 0 { + opts = append(opts, zipkinhttp.MaxBacklog(t.zipkinMaxBacklogSize)) + } + if t.zipkinMaxBatchSize != 0 { + opts = append(opts, zipkinhttp.BatchSize(t.zipkinMaxBatchSize)) + } + if t.zipkinMaxBatchInterval != 0 { + opts = append(opts, zipkinhttp.BatchInterval(t.zipkinMaxBatchInterval)) + } + reporter := zipkinhttp.NewReporter(t.Endpoint, opts...) + recorder, err := zipkin.NewEndpoint(t.serviceName, t.serviceEndpoint) + if err != nil { + log.Warningf("build Zipkin endpoint found err: %v", err) + } + tracer, err := zipkin.NewTracer( + reporter, + zipkin.WithLocalEndpoint(recorder), + zipkin.WithSharedSpans(t.clientServer), + ) + if err != nil { + return err + } + t.tracer = zipkinot.Wrap(tracer) + + t.tagSet = tagByProvider["default"] + return err +} + +// Name implements the Handler interface. +func (t *trace) Name() string { return "trace" } + +// ServeDNS implements the plugin.Handle interface. 
+func (t *trace) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + trace := false + if t.every > 0 { + queryNr := atomic.AddUint64(&t.count, 1) + + if queryNr%t.every == 0 { + trace = true + } + } + span := ot.SpanFromContext(ctx) + if !trace || span != nil { + return plugin.NextOrFailure(t.Name(), t.Next, ctx, w, r) + } + + var spanCtx ot.SpanContext + if val := ctx.Value(dnsserver.HTTPRequestKey{}); val != nil { + if httpReq, ok := val.(*http.Request); ok { + spanCtx, _ = t.Tracer().Extract(ot.HTTPHeaders, ot.HTTPHeadersCarrier(httpReq.Header)) + } + } + + req := request.Request{W: w, Req: r} + span = t.Tracer().StartSpan(defaultTopLevelSpanName, otext.RPCServerOption(spanCtx)) + defer span.Finish() + + switch spanCtx := span.Context().(type) { + case zipkinot.SpanContext: + metadata.SetValueFunc(ctx, metaTraceIdKey, func() string { return spanCtx.TraceID.String() }) + case ddtrace.SpanContext: + metadata.SetValueFunc(ctx, metaTraceIdKey, func() string { return fmt.Sprint(spanCtx.TraceID()) }) + } + + rw := dnstest.NewRecorder(w) + ctx = ot.ContextWithSpan(ctx, span) + status, err := plugin.NextOrFailure(t.Name(), t.Next, ctx, rw, r) + + span.SetTag(t.tagSet.Name, req.Name()) + span.SetTag(t.tagSet.Type, req.Type()) + span.SetTag(t.tagSet.Proto, req.Proto()) + span.SetTag(t.tagSet.Remote, req.IP()) + rc := rw.Rcode + if !plugin.ClientWrite(status) { + // when no response was written, fallback to status returned from next plugin as this status + // is actually used as rcode of DNS response + // see https://github.com/coredns/coredns/blob/master/core/dnsserver/server.go#L318 + rc = status + } + span.SetTag(t.tagSet.Rcode, rcode.ToString(rc)) + if err != nil { + otext.Error.Set(span, true) + span.LogFields(otlog.Event("error"), otlog.Error(err)) + } + + return status, err +} diff --git a/plugin/trace/trace_test.go b/plugin/trace/trace_test.go new file mode 100644 index 0000000..c260501 --- /dev/null +++ b/plugin/trace/trace_test.go @@ -0,0 +1,174 @@ +package trace + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/pkg/rcode" + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" + "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go/mocktracer" +) + +func TestStartup(t *testing.T) { + m, err := traceParse(caddy.NewTestController("dns", `trace`)) + if err != nil { + t.Errorf("Error parsing test input: %s", err) + return + } + if m.Name() != "trace" { + t.Errorf("Wrong name from GetName: %s", m.Name()) + } + err = m.OnStartup() + if err != nil { + t.Errorf("Error starting tracing plugin: %s", err) + return + } + + if m.tagSet != tagByProvider["default"] { + t.Errorf("TagSet by proviser hasn't been corectly initialized") + } + + if m.Tracer() == nil { + t.Errorf("Error, no tracer created") + } +} + +func TestTrace(t *testing.T) { + cases := []struct { + name string + rcode int + status int + question *dns.Msg + err error + }{ + { + name: "NXDOMAIN", + rcode: dns.RcodeNameError, + status: dns.RcodeSuccess, + question: new(dns.Msg).SetQuestion("example.org.", dns.TypeA), + }, + { + name: "NOERROR", + rcode: dns.RcodeSuccess, + status: dns.RcodeSuccess, + question: new(dns.Msg).SetQuestion("example.net.", dns.TypeCNAME), + }, + { + name: "SERVFAIL", + 
rcode: dns.RcodeServerFailure, + status: dns.RcodeSuccess, + question: new(dns.Msg).SetQuestion("example.net.", dns.TypeA), + err: errors.New("test error"), + }, + { + name: "No response written", + rcode: dns.RcodeServerFailure, + status: dns.RcodeServerFailure, + question: new(dns.Msg).SetQuestion("example.net.", dns.TypeA), + err: errors.New("test error"), + }, + } + defaultTagSet := tagByProvider["default"] + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + w := dnstest.NewRecorder(&test.ResponseWriter{}) + m := mocktracer.New() + tr := &trace{ + Next: test.HandlerFunc(func(_ context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + if plugin.ClientWrite(tc.status) { + m := new(dns.Msg) + m.SetRcode(r, tc.rcode) + w.WriteMsg(m) + } + return tc.status, tc.err + }), + every: 1, + tracer: m, + tagSet: defaultTagSet, + } + ctx := context.TODO() + if _, err := tr.ServeDNS(ctx, w, tc.question); err != nil && tc.err == nil { + t.Fatalf("Error during tr.ServeDNS(ctx, w, %v): %v", tc.question, err) + } + + fs := m.FinishedSpans() + // Each trace consists of two spans; the root and the Next function. + if len(fs) != 2 { + t.Fatalf("Unexpected span count: len(fs): want 2, got %v", len(fs)) + } + + rootSpan := fs[1] + req := request.Request{W: w, Req: tc.question} + if rootSpan.OperationName != defaultTopLevelSpanName { + t.Errorf("Unexpected span name: rootSpan.Name: want %v, got %v", defaultTopLevelSpanName, rootSpan.OperationName) + } + + if rootSpan.Tag(defaultTagSet.Name) != req.Name() { + t.Errorf("Unexpected span tag: rootSpan.Tag(%v): want %v, got %v", defaultTagSet.Name, req.Name(), rootSpan.Tag(defaultTagSet.Name)) + } + if rootSpan.Tag(defaultTagSet.Type) != req.Type() { + t.Errorf("Unexpected span tag: rootSpan.Tag(%v): want %v, got %v", defaultTagSet.Type, req.Type(), rootSpan.Tag(defaultTagSet.Type)) + } + if rootSpan.Tag(defaultTagSet.Proto) != req.Proto() { + t.Errorf("Unexpected span tag: rootSpan.Tag(%v): want %v, got %v", defaultTagSet.Proto, req.Proto(), rootSpan.Tag(defaultTagSet.Proto)) + } + if rootSpan.Tag(defaultTagSet.Remote) != req.IP() { + t.Errorf("Unexpected span tag: rootSpan.Tag(%v): want %v, got %v", defaultTagSet.Remote, req.IP(), rootSpan.Tag(defaultTagSet.Remote)) + } + if rootSpan.Tag(defaultTagSet.Rcode) != rcode.ToString(tc.rcode) { + t.Errorf("Unexpected span tag: rootSpan.Tag(%v): want %v, got %v", defaultTagSet.Rcode, rcode.ToString(tc.rcode), rootSpan.Tag(defaultTagSet.Rcode)) + } + if tc.err != nil && rootSpan.Tag("error") != true { + t.Errorf("Unexpected span tag: rootSpan.Tag(%v): want %v, got %v", "error", true, rootSpan.Tag("error")) + } + }) + } +} + +func TestTrace_DOH_TraceHeaderExtraction(t *testing.T) { + w := dnstest.NewRecorder(&test.ResponseWriter{}) + m := mocktracer.New() + tr := &trace{ + Next: test.HandlerFunc(func(_ context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + if plugin.ClientWrite(dns.RcodeSuccess) { + m := new(dns.Msg) + m.SetRcode(r, dns.RcodeSuccess) + w.WriteMsg(m) + } + return dns.RcodeSuccess, nil + }), + every: 1, + tracer: m, + } + q := new(dns.Msg).SetQuestion("example.net.", dns.TypeA) + + req := httptest.NewRequest(http.MethodPost, "/dns-query", nil) + + outsideSpan := m.StartSpan("test-header-span") + outsideSpan.Tracer().Inject(outsideSpan.Context(), opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(req.Header)) + defer outsideSpan.Finish() + + ctx := context.TODO() + ctx = context.WithValue(ctx, dnsserver.HTTPRequestKey{}, req) + + tr.ServeDNS(ctx, w, q) + + fs := 
m.FinishedSpans()
+	rootCoreDNSspan := fs[1]
+	rootCoreDNSTraceID := rootCoreDNSspan.Context().(mocktracer.MockSpanContext).TraceID
+	outsideSpanTraceID := outsideSpan.Context().(mocktracer.MockSpanContext).TraceID
+	if rootCoreDNSTraceID != outsideSpanTraceID {
+		t.Errorf("Unexpected traceID: rootSpan.TraceID: want %v, got %v", outsideSpanTraceID, rootCoreDNSTraceID)
+	}
+}
diff --git a/plugin/transfer/README.md b/plugin/transfer/README.md
new file mode 100644
index 0000000..43c1623
--- /dev/null
+++ b/plugin/transfer/README.md
@@ -0,0 +1,59 @@
+# transfer
+
+## Name
+
+*transfer* - perform (outgoing) zone transfers for other plugins.
+
+## Description
+
+This plugin answers zone transfers for authoritative plugins that implement `transfer.Transferer`.
+
+*transfer* answers full zone transfer (AXFR) requests and incremental zone transfer (IXFR) requests
+with AXFR fallback if the zone has changed.
+
+When a plugin wants to notify its secondaries it will call back into the *transfer* plugin.
+
+The following plugins implement zone transfers using this plugin: *file*, *auto*, *secondary*, and
+*kubernetes*. See `transfer.go` for implementation details if you are a plugin author that wants to
+use this plugin.
+
+## Syntax
+
+~~~
+transfer [ZONE...] {
+  to ADDRESS...
+}
+~~~
+
+ * **ZONE** The zones *transfer* will answer zone transfer requests for. If left blank, the zones
+   are inherited from the enclosing server block. To answer zone transfers for a given zone,
+   there must be another plugin in the same server block that serves the same zone and implements
+   `transfer.Transferer`.
+
+ * `to` **ADDRESS...** The hosts *transfer* will transfer to. Use `*` to permit transfers to all
+   addresses. Zone change notifications are sent to all **ADDRESS** entries that are an IP address or
+   an IP address and port, e.g. `1.2.3.4`, `12:34::56`, `1.2.3.4:5300`, `[12:34::56]:5300`.
+   `to` may be specified multiple times.
+
+You can use the _acl_ plugin to further restrict the hosts permitted to receive a zone transfer.
+See the example below.
+
+## Examples
+
+Use in conjunction with the _acl_ plugin to restrict access to subnet 10.1.0.0/16.
+
+```
+...
+    acl {
+        allow type AXFR net 10.1.0.0/16
+        allow type IXFR net 10.1.0.0/16
+        block type AXFR net *
+        block type IXFR net *
+    }
+    transfer {
+        to *
+    }
+...
+```
+
+Each plugin that can use _transfer_ includes an example of use in their respective documentation.
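+
+For plugin authors, the contract is the `Transfer(zone, serial)` method documented in `transfer.go`. The snippet below is only a rough, hypothetical sketch of such an implementation; `staticZone` and its fields are illustrative and not part of CoreDNS:
+
+```go
+package example
+
+import (
+	"github.com/coredns/coredns/plugin/transfer"
+
+	"github.com/miekg/dns"
+)
+
+// staticZone answers transfers for a single, fixed zone (illustrative only).
+type staticZone struct {
+	origin string   // zone origin, e.g. "example.org."
+	serial uint32   // current SOA serial
+	soa    dns.RR   // the zone's SOA record
+	rrs    []dns.RR // all other records, including NS and glue
+}
+
+// Transfer sends the SOA first, then the remaining records, then the SOA again
+// to end an AXFR. For an IXFR whose serial is already current it sends a single
+// SOA, which the *transfer* plugin turns into an "up to date" response.
+func (z *staticZone) Transfer(zone string, serial uint32) (<-chan []dns.RR, error) {
+	if zone != z.origin {
+		return nil, transfer.ErrNotAuthoritative
+	}
+	ch := make(chan []dns.RR, 3) // buffered, so we can fill and close without blocking
+	defer close(ch)
+	ch <- []dns.RR{z.soa}
+	if serial != 0 && serial >= z.serial {
+		return ch, nil // IXFR and the requester is current: only the SOA is returned
+	}
+	ch <- z.rrs
+	ch <- []dns.RR{z.soa}
+	return ch, nil
+}
+```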
\ No newline at end of file diff --git a/plugin/transfer/failed_write_test.go b/plugin/transfer/failed_write_test.go new file mode 100644 index 0000000..e60fd50 --- /dev/null +++ b/plugin/transfer/failed_write_test.go @@ -0,0 +1,30 @@ +package transfer + +import ( + "context" + "fmt" + "testing" + + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +type badwriter struct { + dns.ResponseWriter +} + +func (w *badwriter) WriteMsg(_ *dns.Msg) error { return fmt.Errorf("failed to write msg") } + +func TestWriteMessageFailed(t *testing.T) { + transfer := newTestTransfer() + ctx := context.TODO() + w := &badwriter{ResponseWriter: &test.ResponseWriter{TCP: true}} + m := &dns.Msg{} + m.SetAxfr("example.org.") + + _, err := transfer.ServeDNS(ctx, w, m) + if err == nil { + t.Error("Expected error, got none") + } +} diff --git a/plugin/transfer/notify.go b/plugin/transfer/notify.go new file mode 100644 index 0000000..26f7668 --- /dev/null +++ b/plugin/transfer/notify.go @@ -0,0 +1,58 @@ +package transfer + +import ( + "fmt" + + "github.com/coredns/coredns/plugin/pkg/rcode" + + "github.com/miekg/dns" +) + +// Notify will send notifies to all configured to hosts IP addresses. The string zone must be lowercased. +func (t *Transfer) Notify(zone string) error { + if t == nil { // t might be nil, mostly expected in tests, so intercept and to a noop in that case + return nil + } + + m := new(dns.Msg) + m.SetNotify(zone) + c := new(dns.Client) + + x := longestMatch(t.xfrs, zone) + if x == nil { + // return without error if there is no matching zone + return nil + } + + var err1 error + for _, t := range x.to { + if t == "*" { + continue + } + if err := sendNotify(c, m, t); err != nil { + err1 = err + } + } + log.Debugf("Sent notifies for zone %q to %v", zone, x.to) + return err1 // this only captures the last error +} + +func sendNotify(c *dns.Client, m *dns.Msg, s string) error { + var err error + + code := dns.RcodeServerFailure + for i := 0; i < 3; i++ { + ret, _, err := c.Exchange(m, s) + if err != nil { + continue + } + code = ret.Rcode + if code == dns.RcodeSuccess { + return nil + } + } + if err != nil { + return fmt.Errorf("notify for zone %q was not accepted by %q: %q", m.Question[0].Name, s, err) + } + return fmt.Errorf("notify for zone %q was not accepted by %q: rcode was %q", m.Question[0].Name, s, rcode.ToString(code)) +} diff --git a/plugin/transfer/select_test.go b/plugin/transfer/select_test.go new file mode 100644 index 0000000..a064b00 --- /dev/null +++ b/plugin/transfer/select_test.go @@ -0,0 +1,58 @@ +package transfer + +import ( + "context" + "fmt" + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +type ( + t1 struct{} + t2 struct{} +) + +func (t t1) Transfer(zone string, serial uint32) (<-chan []dns.RR, error) { + const z = "example.org." + if zone != z { + return nil, ErrNotAuthoritative + } + return nil, fmt.Errorf(z) +} +func (t t2) Transfer(zone string, serial uint32) (<-chan []dns.RR, error) { + const z = "sub.example.org." + if zone != z { + return nil, ErrNotAuthoritative + } + return nil, fmt.Errorf(z) +} + +func TestZoneSelection(t *testing.T) { + tr := &Transfer{ + Transferers: []Transferer{t1{}, t2{}}, + xfrs: []*xfr{ + { + Zones: []string{"example.org."}, + to: []string{"192.0.2.1"}, // RFC 5737 IP, no interface should have this address. 
+ }, + { + Zones: []string{"sub.example.org."}, + to: []string{"*"}, + }, + }, + } + r := new(dns.Msg) + r.SetAxfr("sub.example.org.") + w := dnstest.NewRecorder(&test.ResponseWriter{TCP: true}) + _, err := tr.ServeDNS(context.TODO(), w, r) + if err == nil { + t.Fatal("Expected error, got nil") + } + if x := err.Error(); x != "sub.example.org." { + t.Errorf("Expected transfer for zone %s, got %s", "sub.example.org", x) + } +} diff --git a/plugin/transfer/setup.go b/plugin/transfer/setup.go new file mode 100644 index 0000000..cd7d209 --- /dev/null +++ b/plugin/transfer/setup.go @@ -0,0 +1,81 @@ +package transfer + +import ( + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/parse" + "github.com/coredns/coredns/plugin/pkg/transport" +) + +func init() { + caddy.RegisterPlugin("transfer", caddy.Plugin{ + ServerType: "dns", + Action: setup, + }) +} + +func setup(c *caddy.Controller) error { + t, err := parseTransfer(c) + + if err != nil { + return plugin.Error("transfer", err) + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + t.Next = next + return t + }) + + c.OnStartup(func() error { + config := dnsserver.GetConfig(c) + t.tsigSecret = config.TsigSecret + // find all plugins that implement Transferer and add them to Transferers + plugins := config.Handlers() + for _, pl := range plugins { + tr, ok := pl.(Transferer) + if !ok { + continue + } + t.Transferers = append(t.Transferers, tr) + } + return nil + }) + + return nil +} + +func parseTransfer(c *caddy.Controller) (*Transfer, error) { + t := &Transfer{} + for c.Next() { + x := &xfr{} + x.Zones = plugin.OriginsFromArgsOrServerBlock(c.RemainingArgs(), c.ServerBlockKeys) + for c.NextBlock() { + switch c.Val() { + case "to": + args := c.RemainingArgs() + if len(args) == 0 { + return nil, c.ArgErr() + } + for _, host := range args { + if host == "*" { + x.to = append(x.to, host) + continue + } + normalized, err := parse.HostPort(host, transport.Port) + if err != nil { + return nil, err + } + x.to = append(x.to, normalized) + } + default: + return nil, plugin.Error("transfer", c.Errf("unknown property %q", c.Val())) + } + } + if len(x.to) == 0 { + return nil, plugin.Error("transfer", c.Err("'to' is required")) + } + t.xfrs = append(t.xfrs, x) + } + return t, nil +} diff --git a/plugin/transfer/setup_test.go b/plugin/transfer/setup_test.go new file mode 100644 index 0000000..ebfe99c --- /dev/null +++ b/plugin/transfer/setup_test.go @@ -0,0 +1,131 @@ +package transfer + +import ( + "testing" + + "github.com/coredns/caddy" +) + +func TestParse(t *testing.T) { + tests := []struct { + input string + zones []string + shouldErr bool + exp *Transfer + }{ + {`transfer example.net example.org { + to 1.2.3.4 5.6.7.8:1053 [1::2]:34 + } + transfer example.com example.edu { + to * 1.2.3.4 + }`, + nil, + false, + &Transfer{ + xfrs: []*xfr{{ + Zones: []string{"example.net.", "example.org."}, + to: []string{"1.2.3.4:53", "5.6.7.8:1053", "[1::2]:34"}, + }, { + Zones: []string{"example.com.", "example.edu."}, + to: []string{"*", "1.2.3.4:53"}, + }}, + }, + }, + // errors + {`transfer example.net example.org { + }`, + nil, + true, + nil, + }, + {`transfer example.net example.org { + invalid option + }`, + nil, + true, + nil, + }, + { + ` + transfer example.com example.edu { + to example.com 1.2.3.4 + }`, + nil, + true, + nil, + }, + { + `transfer { + to 1.2.3.4 5.6.7.8:1053 [1::2]:34 + }`, + []string{"."}, + false, + &Transfer{ + 
xfrs: []*xfr{{ + Zones: []string{"."}, + to: []string{"1.2.3.4:53", "5.6.7.8:1053", "[1::2]:34"}, + }}, + }, + }, + } + for i, tc := range tests { + c := caddy.NewTestController("dns", tc.input) + c.ServerBlockKeys = append(c.ServerBlockKeys, tc.zones...) + + transfer, err := parseTransfer(c) + + if err == nil && tc.shouldErr { + t.Fatalf("Test %d expected errors, but got no error", i) + } + if err != nil && !tc.shouldErr { + t.Fatalf("Test %d expected no errors, but got '%v'", i, err) + } + if tc.exp == nil && transfer != nil { + t.Fatalf("Test %d expected %v xfrs, got %#v", i, tc.exp, transfer) + } + if tc.shouldErr { + continue + } + + if len(tc.exp.xfrs) != len(transfer.xfrs) { + t.Fatalf("Test %d expected %d xfrs, got %d", i, len(tc.exp.xfrs), len(transfer.xfrs)) + } + for j, x := range transfer.xfrs { + // Check Zones + if len(tc.exp.xfrs[j].Zones) != len(x.Zones) { + t.Fatalf("Test %d expected %d zones, got %d", i, len(tc.exp.xfrs[i].Zones), len(x.Zones)) + } + for k, zone := range x.Zones { + if tc.exp.xfrs[j].Zones[k] != zone { + t.Errorf("Test %d expected zone %v, got %v", i, tc.exp.xfrs[j].Zones[k], zone) + } + } + // Check to + if len(tc.exp.xfrs[j].to) != len(x.to) { + t.Fatalf("Test %d expected %d 'to' values, got %d", i, len(tc.exp.xfrs[i].to), len(x.to)) + } + for k, to := range x.to { + if tc.exp.xfrs[j].to[k] != to { + t.Errorf("Test %d expected %v in 'to', got %v", i, tc.exp.xfrs[j].to[k], to) + } + } + } + } +} + +func TestSetup(t *testing.T) { + c := caddy.NewTestController("dns", "transfer") + if err := setup(c); err == nil { + t.Fatal("Expected errors, but got nil") + } + + c = caddy.NewTestController("dns", `transfer example.net example.org { + to 1.2.3.4 5.6.7.8:1053 [1::2]:34 + } + transfer example.com example.edu { + to * 1.2.3.4 + }`) + if err := setup(c); err != nil { + t.Fatalf("Expected no errors, but got %v", err) + } +} diff --git a/plugin/transfer/transfer.go b/plugin/transfer/transfer.go new file mode 100644 index 0000000..f0b42e9 --- /dev/null +++ b/plugin/transfer/transfer.go @@ -0,0 +1,221 @@ +package transfer + +import ( + "context" + "errors" + "net" + + "github.com/coredns/coredns/plugin" + clog "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +var log = clog.NewWithPlugin("transfer") + +// Transfer is a plugin that handles zone transfers. +type Transfer struct { + Transferers []Transferer // List of plugins that implement Transferer + xfrs []*xfr + tsigSecret map[string]string + Next plugin.Handler +} + +type xfr struct { + Zones []string + to []string +} + +// Transferer may be implemented by plugins to enable zone transfers +type Transferer interface { + // Transfer returns a channel to which it writes responses to the transfer request. + // If the plugin is not authoritative for the zone, it should immediately return the + // transfer.ErrNotAuthoritative error. This is important otherwise the transfer plugin can + // use plugin X while it should transfer the data from plugin Y. + // + // If serial is 0, handle as an AXFR request. Transfer should send all records + // in the zone to the channel. The SOA should be written to the channel first, followed + // by all other records, including all NS + glue records. The implementation is also responsible + // for sending the last SOA record (to signal end of the transfer). This plugin will just grab + // these records and send them back to the requester, there is little validation done. 
+ // + // If serial is not 0, it will be handled as an IXFR request. If the serial is equal to or greater (newer) than + // the current serial for the zone, send a single SOA record to the channel and then close it. + // If the serial is less (older) than the current serial for the zone, perform an AXFR fallback + // by proceeding as if an AXFR was requested (as above). + Transfer(zone string, serial uint32) (<-chan []dns.RR, error) +} + +var ( + // ErrNotAuthoritative is returned by Transfer() when the plugin is not authoritative for the zone. + ErrNotAuthoritative = errors.New("not authoritative for zone") +) + +// ServeDNS implements the plugin.Handler interface. +func (t *Transfer) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + state := request.Request{W: w, Req: r} + if state.QType() != dns.TypeAXFR && state.QType() != dns.TypeIXFR { + return plugin.NextOrFailure(t.Name(), t.Next, ctx, w, r) + } + + if state.Proto() != "tcp" { + return dns.RcodeRefused, nil + } + + x := longestMatch(t.xfrs, state.QName()) + if x == nil { + return plugin.NextOrFailure(t.Name(), t.Next, ctx, w, r) + } + + if !x.allowed(state) { + // write msg here, so logging will pick it up + m := new(dns.Msg) + m.SetRcode(r, dns.RcodeRefused) + w.WriteMsg(m) + return 0, nil + } + + // Get serial from request if this is an IXFR. + var serial uint32 + if state.QType() == dns.TypeIXFR { + if len(r.Ns) != 1 { + return dns.RcodeServerFailure, nil + } + soa, ok := r.Ns[0].(*dns.SOA) + if !ok { + return dns.RcodeServerFailure, nil + } + serial = soa.Serial + } + + // Get a receiving channel from the first Transferer plugin that returns one. + var pchan <-chan []dns.RR + var err error + for _, p := range t.Transferers { + pchan, err = p.Transfer(state.QName(), serial) + if err == ErrNotAuthoritative { + // plugin was not authoritative for the zone, try next plugin + continue + } + if err != nil { + return dns.RcodeServerFailure, err + } + break + } + + if pchan == nil { + return plugin.NextOrFailure(t.Name(), t.Next, ctx, w, r) + } + + // Send response to client + ch := make(chan *dns.Envelope) + tr := new(dns.Transfer) + if r.IsTsig() != nil { + tr.TsigSecret = t.tsigSecret + } + errCh := make(chan error) + go func() { + if err := tr.Out(w, r, ch); err != nil { + errCh <- err + } + close(errCh) + }() + + rrs := []dns.RR{} + l := 0 + var soa *dns.SOA + for records := range pchan { + if x, ok := records[0].(*dns.SOA); ok && soa == nil { + soa = x + } + rrs = append(rrs, records...) + if len(rrs) > 500 { + select { + case ch <- &dns.Envelope{RR: rrs}: + case err := <-errCh: + return dns.RcodeServerFailure, err + } + l += len(rrs) + rrs = []dns.RR{} + } + } + + // if we are here and we only hold 1 soa (len(rrs) == 1) and soa != nil, and IXFR fallback should + // be performed. We haven't send anything on ch yet, so that can be closed (and waited for), and we only + // need to return the SOA back to the client and return. + if len(rrs) == 1 && soa != nil { // soa should never be nil... 
+ close(ch) + err := <-errCh + if err != nil { + return dns.RcodeServerFailure, err + } + + m := new(dns.Msg) + m.SetReply(r) + m.Answer = []dns.RR{soa} + w.WriteMsg(m) + + log.Infof("Outgoing noop, incremental transfer for up to date zone %q to %s for %d SOA serial", state.QName(), state.IP(), soa.Serial) + return 0, nil + } + + if len(rrs) > 0 { + select { + case ch <- &dns.Envelope{RR: rrs}: + case err := <-errCh: + return dns.RcodeServerFailure, err + } + l += len(rrs) + } + + close(ch) // Even though we close the channel here, we still have + err = <-errCh // to wait before we can return and close the connection. + if err != nil { + return dns.RcodeServerFailure, err + } + + logserial := uint32(0) + if soa != nil { + logserial = soa.Serial + } + log.Infof("Outgoing transfer of %d records of zone %q to %s for %d SOA serial", l, state.QName(), state.IP(), logserial) + return 0, nil +} + +func (x xfr) allowed(state request.Request) bool { + for _, h := range x.to { + if h == "*" { + return true + } + to, _, err := net.SplitHostPort(h) + if err != nil { + return false + } + // If remote IP matches we accept. TODO(): make this works with ranges + if to == state.IP() { + return true + } + } + return false +} + +// Find the first transfer instance for which the queried zone is the longest match. When nothing +// is found nil is returned. +func longestMatch(xfrs []*xfr, name string) *xfr { + // TODO(xxx): optimize and make it a map (or maps) + var x *xfr + zone := "" // longest zone match wins + for _, xfr := range xfrs { + if z := plugin.Zones(xfr.Zones).Matches(name); z != "" { + if z > zone { + zone = z + x = xfr + } + } + } + return x +} + +// Name implements the Handler interface. +func (Transfer) Name() string { return "transfer" } diff --git a/plugin/transfer/transfer_test.go b/plugin/transfer/transfer_test.go new file mode 100644 index 0000000..79233d1 --- /dev/null +++ b/plugin/transfer/transfer_test.go @@ -0,0 +1,278 @@ +package transfer + +import ( + "context" + "fmt" + "testing" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +// transfererPlugin implements transfer.Transferer and plugin.Handler. +type transfererPlugin struct { + Zone string + Serial uint32 + Next plugin.Handler +} + +// Name implements plugin.Handler. +func (*transfererPlugin) Name() string { return "transfererplugin" } + +// ServeDNS implements plugin.Handler. +func (p *transfererPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + if r.Question[0].Name != p.Zone { + return p.Next.ServeDNS(ctx, w, r) + } + return 0, nil +} + +// Transfer implements transfer.Transferer - it returns a static AXFR response, or +// if serial is current, an abbreviated IXFR response. 
+func (p *transfererPlugin) Transfer(zone string, serial uint32) (<-chan []dns.RR, error) { + if zone != p.Zone { + return nil, ErrNotAuthoritative + } + ch := make(chan []dns.RR, 3) // sending 3 bits and don't want to block, nor do a waitgroup + defer close(ch) + ch <- []dns.RR{test.SOA(fmt.Sprintf("%s 100 IN SOA ns.dns.%s hostmaster.%s %d 7200 1800 86400 100", p.Zone, p.Zone, p.Zone, p.Serial))} + if serial >= p.Serial { + return ch, nil + } + ch <- []dns.RR{ + test.NS(fmt.Sprintf("%s 100 IN NS ns.dns.%s", p.Zone, p.Zone)), + test.A(fmt.Sprintf("ns.dns.%s 100 IN A 1.2.3.4", p.Zone)), + } + ch <- []dns.RR{test.SOA(fmt.Sprintf("%s 100 IN SOA ns.dns.%s hostmaster.%s %d 7200 1800 86400 100", p.Zone, p.Zone, p.Zone, p.Serial))} + return ch, nil +} + +type terminatingPlugin struct{} + +// Name implements plugin.Handler. +func (*terminatingPlugin) Name() string { return "testplugin" } + +// ServeDNS implements plugin.Handler that returns NXDOMAIN for all requests. +func (*terminatingPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + m := new(dns.Msg) + m.SetRcode(r, dns.RcodeNameError) + w.WriteMsg(m) + return dns.RcodeNameError, nil +} + +func newTestTransfer() *Transfer { + nextPlugin1 := transfererPlugin{Zone: "example.com.", Serial: 12345} + nextPlugin2 := transfererPlugin{Zone: "example.org.", Serial: 12345} + nextPlugin2.Next = &terminatingPlugin{} + nextPlugin1.Next = &nextPlugin2 + + transfer := &Transfer{ + Transferers: []Transferer{&nextPlugin1, &nextPlugin2}, + xfrs: []*xfr{ + { + Zones: []string{"example.org."}, + to: []string{"*"}, + }, + { + Zones: []string{"example.com."}, + to: []string{"*"}, + }, + }, + Next: &nextPlugin1, + } + return transfer +} + +func TestTransferNonZone(t *testing.T) { + transfer := newTestTransfer() + ctx := context.TODO() + + for _, tc := range []string{"sub.example.org.", "example.test."} { + w := dnstest.NewRecorder(&test.ResponseWriter{TCP: true}) + m := &dns.Msg{} + m.SetAxfr(tc) + + _, err := transfer.ServeDNS(ctx, w, m) + if err != nil { + t.Error(err) + } + + if w.Msg == nil { + t.Fatalf("Got nil message for AXFR %s", tc) + } + + if w.Msg.Rcode != dns.RcodeNameError { + t.Errorf("Expected NXDOMAIN for AXFR %s got %s", tc, dns.RcodeToString[w.Msg.Rcode]) + } + } +} + +func TestTransferNotAXFRorIXFR(t *testing.T) { + transfer := newTestTransfer() + + ctx := context.TODO() + w := dnstest.NewRecorder(&test.ResponseWriter{TCP: true}) + m := &dns.Msg{} + m.SetQuestion("test.domain.", dns.TypeA) + + _, err := transfer.ServeDNS(ctx, w, m) + if err != nil { + t.Error(err) + } + + if w.Msg == nil { + t.Fatal("Got nil message") + } + + if w.Msg.Rcode != dns.RcodeNameError { + t.Errorf("Expected NXDOMAIN got %s", dns.RcodeToString[w.Msg.Rcode]) + } +} + +func TestTransferAXFRExampleOrg(t *testing.T) { + transfer := newTestTransfer() + + ctx := context.TODO() + w := dnstest.NewMultiRecorder(&test.ResponseWriter{TCP: true}) + m := &dns.Msg{} + m.SetAxfr(transfer.xfrs[0].Zones[0]) + + _, err := transfer.ServeDNS(ctx, w, m) + if err != nil { + t.Error(err) + } + + validateAXFRResponse(t, w) +} + +func TestTransferAXFRExampleCom(t *testing.T) { + transfer := newTestTransfer() + + ctx := context.TODO() + w := dnstest.NewMultiRecorder(&test.ResponseWriter{TCP: true}) + m := &dns.Msg{} + m.SetAxfr(transfer.xfrs[1].Zones[0]) + + _, err := transfer.ServeDNS(ctx, w, m) + if err != nil { + t.Error(err) + } + + validateAXFRResponse(t, w) +} + +func TestTransferIXFRCurrent(t *testing.T) { + transfer := newTestTransfer() + + testPlugin 
:= transfer.Transferers[0].(*transfererPlugin) + + ctx := context.TODO() + w := dnstest.NewMultiRecorder(&test.ResponseWriter{TCP: true}) + m := &dns.Msg{} + m.SetIxfr(transfer.xfrs[0].Zones[0], testPlugin.Serial, "ns.dns."+testPlugin.Zone, "hostmaster.dns."+testPlugin.Zone) + + _, err := transfer.ServeDNS(ctx, w, m) + if err != nil { + t.Error(err) + } + + if len(w.Msgs) == 0 { + t.Fatal("Did not get back a zone response") + } + + if len(w.Msgs[0].Answer) != 1 { + t.Logf("%+v\n", w) + t.Fatalf("Expected 1 answer, got %d", len(w.Msgs[0].Answer)) + } + + // Ensure the answer is the SOA + if w.Msgs[0].Answer[0].Header().Rrtype != dns.TypeSOA { + t.Error("Answer does not contain the SOA record") + } +} + +func TestTransferIXFRFallback(t *testing.T) { + transfer := newTestTransfer() + + testPlugin := transfer.Transferers[0].(*transfererPlugin) + + ctx := context.TODO() + w := dnstest.NewMultiRecorder(&test.ResponseWriter{TCP: true}) + m := &dns.Msg{} + m.SetIxfr( + transfer.xfrs[0].Zones[0], + testPlugin.Serial-1, + "ns.dns."+testPlugin.Zone, + "hostmaster.dns."+testPlugin.Zone, + ) + + _, err := transfer.ServeDNS(ctx, w, m) + if err != nil { + t.Error(err) + } + + validateAXFRResponse(t, w) +} + +func validateAXFRResponse(t *testing.T, w *dnstest.MultiRecorder) { + if len(w.Msgs) == 0 { + t.Fatal("Did not get back a zone response") + } + + if len(w.Msgs[0].Answer) == 0 { + t.Logf("%+v\n", w) + t.Fatal("Did not get back an answer") + } + + // Ensure the answer starts with SOA + if w.Msgs[0].Answer[0].Header().Rrtype != dns.TypeSOA { + t.Error("Answer does not start with SOA record") + } + + // Ensure the answer ends with SOA + if w.Msgs[len(w.Msgs)-1].Answer[len(w.Msgs[len(w.Msgs)-1].Answer)-1].Header().Rrtype != dns.TypeSOA { + t.Error("Answer does not end with SOA record") + } + + // Ensure the answer is the expected length + c := 0 + for _, m := range w.Msgs { + c += len(m.Answer) + } + if c != 4 { + t.Errorf("Answer is not the expected length (expected 4, got %d)", c) + } +} + +func TestTransferNotAllowed(t *testing.T) { + nextPlugin := transfererPlugin{Zone: "example.org.", Serial: 12345} + + transfer := Transfer{ + Transferers: []Transferer{&nextPlugin}, + xfrs: []*xfr{ + { + Zones: []string{"example.org."}, + to: []string{"1.2.3.4"}, + }, + }, + Next: &nextPlugin, + } + + ctx := context.TODO() + w := dnstest.NewRecorder(&test.ResponseWriter{TCP: true}) + m := &dns.Msg{} + m.SetAxfr(transfer.xfrs[0].Zones[0]) + + _, err := transfer.ServeDNS(ctx, w, m) + + if err != nil { + t.Error(err) + } + + if w.Msg.Rcode != dns.RcodeRefused { + t.Errorf("Expected REFUSED response code, got %s", dns.RcodeToString[w.Msg.Rcode]) + } +} diff --git a/plugin/tsig/README.md b/plugin/tsig/README.md new file mode 100644 index 0000000..d73b9ca --- /dev/null +++ b/plugin/tsig/README.md @@ -0,0 +1,118 @@ +# tsig + +## Name + +*tsig* - define TSIG keys, validate incoming TSIG signed requests and sign responses. + +## Description + +With *tsig*, you can define CoreDNS's TSIG secret keys. Using those keys, *tsig* validates incoming TSIG requests and signs +responses to those requests. It does not itself sign requests outgoing from CoreDNS; it is up to the +respective plugins sending those requests to sign them using the keys defined by *tsig*. + +The *tsig* plugin can also require that incoming requests be signed for certain query types, refusing requests that do not comply. + +## Syntax + +~~~ +tsig [ZONE...] { + secret NAME KEY + secrets FILE + require [QTYPE...] 
+}
+~~~
+
+ * **ZONE** - the zones *tsig* will TSIG. By default, the zones from the server block are used.
+
+ * `secret` **NAME** **KEY** - specifies a TSIG secret for **NAME** with **KEY**. Use this option more than once
+   to define multiple secrets. Secrets are global to the server instance, not just to the enclosing **ZONE**.
+
+ * `secrets` **FILE** - same as `secret`, but load the secrets from a file. The file may define any number
+   of unique keys, each in the following `named.conf` format:
+   ```cgo
+   key "example." {
+     secret "X28hl0BOfAL5G0jsmJWSacrwn7YRm2f6U5brnzwWEus=";
+   };
+   ```
+   Each key may also specify an `algorithm`, e.g. `algorithm hmac-sha256;`, but this is currently ignored by the plugin.
+
+ * `require` **QTYPE...** - the query types that must be TSIG'd. Requests of the specified types
+   will be `REFUSED` if they are not signed. `require all` will require requests of all types to be
+   signed. `require none` will not require requests of any type to be signed. The default is to not require signing.
+
+## Examples
+
+Require TSIG signed transactions for transfer requests to `example.zone`.
+
+```
+example.zone {
+  tsig {
+    secret example.zone.key. NoTCJU+DMqFWywaPyxSijrDEA/eC3nK0xi3AMEZuPVk=
+    require AXFR IXFR
+  }
+  transfer {
+    to *
+  }
+}
+```
+
+Require TSIG signed transactions for all requests to `auth.zone`.
+
+```
+auth.zone {
+  tsig {
+    secret auth.zone.key. NoTCJU+DMqFWywaPyxSijrDEA/eC3nK0xi3AMEZuPVk=
+    require all
+  }
+  forward . 10.1.0.2
+}
+```
+
+## Bugs
+
+### Secondary
+
+TSIG transfers are not yet implemented for the *secondary* plugin. The *secondary* plugin will not sign its zone transfer requests.
+
+### Zone Transfer Notifies
+
+With the *transfer* plugin, zone transfer notifications from CoreDNS are not TSIG signed.
+
+### Special Considerations for Forwarding Servers (RFC 8945 5.5)
+
+https://datatracker.ietf.org/doc/html/rfc8945#section-5.5
+
+CoreDNS deviates from this section as follows:
+
+* RFC requirement:
+  > If the name on the TSIG is not of a secret that the server shares with the originator, the server MUST forward the message unchanged including the TSIG.
+
+  CoreDNS behavior: If the zone of the request matches the _tsig_ plugin zones, then the TSIG record
+  is always stripped. But even when the _tsig_ plugin is not involved, the _forward_ plugin
+  may alter the message with compression, which would cause validation failure
+  at the destination.
+
+* RFC requirement:
+  > If the TSIG passes all checks, the forwarding server MUST, if possible, include a TSIG of its own to the destination or the next forwarder.
+
+  CoreDNS behavior: If the zone of the request matches the _tsig_ plugin zones, the _forward_ plugin will
+  proxy the request upstream without TSIG.
+
+* RFC requirement:
+  > If no transaction security is available to the destination and the message is a query, and if the corresponding response has the AD flag (see RFC4035) set, the forwarder MUST clear the AD flag before adding the TSIG to the response and returning the result to the system from which it received the query.
+
+  CoreDNS behavior: The AD flag is not cleared.
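+
+For reference, a client can produce a TSIG signed transfer request matching the first example above with the `miekg/dns` library. This is only a hedged sketch; the server address is illustrative:
+
+```go
+package main
+
+import (
+	"fmt"
+	"time"
+
+	"github.com/miekg/dns"
+)
+
+func main() {
+	m := new(dns.Msg)
+	m.SetAxfr("example.zone.")
+	// Sign the request with the key name and algorithm configured via `secret`.
+	m.SetTsig("example.zone.key.", dns.HmacSHA256, 300, time.Now().Unix())
+
+	tr := new(dns.Transfer)
+	tr.TsigSecret = map[string]string{
+		"example.zone.key.": "NoTCJU+DMqFWywaPyxSijrDEA/eC3nK0xi3AMEZuPVk=",
+	}
+	// 127.0.0.1:53 stands in for the CoreDNS instance running the example config.
+	env, err := tr.In(m, "127.0.0.1:53")
+	if err != nil {
+		panic(err)
+	}
+	for e := range env {
+		if e.Error != nil {
+			panic(e.Error)
+		}
+		for _, rr := range e.RR {
+			fmt.Println(rr)
+		}
+	}
+}
+```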
diff --git a/plugin/tsig/setup.go b/plugin/tsig/setup.go new file mode 100644 index 0000000..a187a4b --- /dev/null +++ b/plugin/tsig/setup.go @@ -0,0 +1,168 @@ +package tsig + +import ( + "bufio" + "fmt" + "io" + "os" + "strings" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + + "github.com/miekg/dns" +) + +func init() { + caddy.RegisterPlugin(pluginName, caddy.Plugin{ + ServerType: "dns", + Action: setup, + }) +} + +func setup(c *caddy.Controller) error { + t, err := parse(c) + if err != nil { + return plugin.Error(pluginName, c.ArgErr()) + } + + config := dnsserver.GetConfig(c) + + config.TsigSecret = t.secrets + + config.AddPlugin(func(next plugin.Handler) plugin.Handler { + t.Next = next + return t + }) + + return nil +} + +func parse(c *caddy.Controller) (*TSIGServer, error) { + t := &TSIGServer{ + secrets: make(map[string]string), + types: defaultQTypes, + } + + for i := 0; c.Next(); i++ { + if i > 0 { + return nil, plugin.ErrOnce + } + + t.Zones = plugin.OriginsFromArgsOrServerBlock(c.RemainingArgs(), c.ServerBlockKeys) + for c.NextBlock() { + switch c.Val() { + case "secret": + args := c.RemainingArgs() + if len(args) != 2 { + return nil, c.ArgErr() + } + k := plugin.Name(args[0]).Normalize() + if _, exists := t.secrets[k]; exists { + return nil, fmt.Errorf("key %q redefined", k) + } + t.secrets[k] = args[1] + case "secrets": + args := c.RemainingArgs() + if len(args) != 1 { + return nil, c.ArgErr() + } + f, err := os.Open(args[0]) + if err != nil { + return nil, err + } + secrets, err := parseKeyFile(f) + if err != nil { + return nil, err + } + for k, s := range secrets { + if _, exists := t.secrets[k]; exists { + return nil, fmt.Errorf("key %q redefined", k) + } + t.secrets[k] = s + } + case "require": + t.types = qTypes{} + args := c.RemainingArgs() + if len(args) == 0 { + return nil, c.ArgErr() + } + if args[0] == "all" { + t.all = true + continue + } + if args[0] == "none" { + continue + } + for _, str := range args { + qt, ok := dns.StringToType[str] + if !ok { + return nil, c.Errf("unknown query type '%s'", str) + } + t.types[qt] = struct{}{} + } + default: + return nil, c.Errf("unknown property '%s'", c.Val()) + } + } + } + return t, nil +} + +func parseKeyFile(f io.Reader) (map[string]string, error) { + secrets := make(map[string]string) + s := bufio.NewScanner(f) + for s.Scan() { + fields := strings.Fields(s.Text()) + if len(fields) == 0 { + continue + } + if fields[0] != "key" { + return nil, fmt.Errorf("unexpected token %q", fields[0]) + } + if len(fields) < 2 { + return nil, fmt.Errorf("expected key name %q", s.Text()) + } + key := strings.Trim(fields[1], "\"{") + if len(key) == 0 { + return nil, fmt.Errorf("expected key name %q", s.Text()) + } + key = plugin.Name(key).Normalize() + if _, ok := secrets[key]; ok { + return nil, fmt.Errorf("key %q redefined", key) + } + key: + for s.Scan() { + fields := strings.Fields(s.Text()) + if len(fields) == 0 { + continue + } + switch fields[0] { + case "algorithm": + continue + case "secret": + if len(fields) < 2 { + return nil, fmt.Errorf("expected secret key %q", s.Text()) + } + secret := strings.Trim(fields[1], "\";") + if len(secret) == 0 { + return nil, fmt.Errorf("expected secret key %q", s.Text()) + } + secrets[key] = secret + case "}": + fallthrough + case "};": + break key + default: + return nil, fmt.Errorf("unexpected token %q", fields[0]) + } + } + if _, ok := secrets[key]; !ok { + return nil, fmt.Errorf("expected secret for key %q", key) + } + } + 
return secrets, nil +} + +var defaultQTypes = qTypes{} diff --git a/plugin/tsig/setup_test.go b/plugin/tsig/setup_test.go new file mode 100644 index 0000000..0d74339 --- /dev/null +++ b/plugin/tsig/setup_test.go @@ -0,0 +1,245 @@ +package tsig + +import ( + "fmt" + "strings" + "testing" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestParse(t *testing.T) { + secrets := map[string]string{ + "name.key.": "test-key", + "name2.key.": "test-key-2", + } + secretConfig := "" + for k, s := range secrets { + secretConfig += fmt.Sprintf("secret %s %s\n", k, s) + } + secretsFile, cleanup, err := test.TempFile(".", `key "name.key." { + secret "test-key"; +}; +key "name2.key." { + secret "test-key2"; +};`) + if err != nil { + t.Fatalf("failed to create temp file: %v", err) + } + defer cleanup() + + tests := []struct { + input string + shouldErr bool + expectedZones []string + expectedQTypes qTypes + expectedSecrets map[string]string + expectedAll bool + }{ + { + input: "tsig {\n " + secretConfig + "}", + expectedZones: []string{"."}, + expectedQTypes: defaultQTypes, + expectedSecrets: secrets, + }, + { + input: "tsig {\n secrets " + secretsFile + "\n}", + expectedZones: []string{"."}, + expectedQTypes: defaultQTypes, + expectedSecrets: secrets, + }, + { + input: "tsig example.com {\n " + secretConfig + "}", + expectedZones: []string{"example.com."}, + expectedQTypes: defaultQTypes, + expectedSecrets: secrets, + }, + { + input: "tsig {\n " + secretConfig + " require all \n}", + expectedZones: []string{"."}, + expectedQTypes: qTypes{}, + expectedAll: true, + expectedSecrets: secrets, + }, + { + input: "tsig {\n " + secretConfig + " require none \n}", + expectedZones: []string{"."}, + expectedQTypes: qTypes{}, + expectedAll: false, + expectedSecrets: secrets, + }, + { + input: "tsig {\n " + secretConfig + " \n require A AAAA \n}", + expectedZones: []string{"."}, + expectedQTypes: qTypes{dns.TypeA: {}, dns.TypeAAAA: {}}, + expectedSecrets: secrets, + }, + { + input: "tsig {\n blah \n}", + shouldErr: true, + }, + { + input: "tsig {\n secret name. 
too many parameters \n}", + shouldErr: true, + }, + { + input: "tsig {\n require \n}", + shouldErr: true, + }, + { + input: "tsig {\n require invalid-qtype \n}", + shouldErr: true, + }, + } + + serverBlockKeys := []string{"."} + for i, test := range tests { + c := caddy.NewTestController("dns", test.input) + c.ServerBlockKeys = serverBlockKeys + ts, err := parse(c) + + if err == nil && test.shouldErr { + t.Fatalf("Test %d expected errors, but got no error.", i) + } else if err != nil && !test.shouldErr { + t.Fatalf("Test %d expected no errors, but got '%v'", i, err) + } + + if test.shouldErr { + continue + } + + if len(test.expectedZones) != len(ts.Zones) { + t.Fatalf("Test %d expected zones '%v', but got '%v'.", i, test.expectedZones, ts.Zones) + } + for j := range test.expectedZones { + if test.expectedZones[j] != ts.Zones[j] { + t.Errorf("Test %d expected zones '%v', but got '%v'.", i, test.expectedZones, ts.Zones) + break + } + } + + if test.expectedAll != ts.all { + t.Errorf("Test %d expected require all to be '%v', but got '%v'.", i, test.expectedAll, ts.all) + } + + if len(test.expectedQTypes) != len(ts.types) { + t.Fatalf("Test %d expected required types '%v', but got '%v'.", i, test.expectedQTypes, ts.types) + } + for qt := range test.expectedQTypes { + if _, ok := ts.types[qt]; !ok { + t.Errorf("Test %d required types '%v', but got '%v'.", i, test.expectedQTypes, ts.types) + break + } + } + + if len(test.expectedSecrets) != len(ts.secrets) { + t.Fatalf("Test %d expected secrets '%v', but got '%v'.", i, test.expectedSecrets, ts.secrets) + } + for qt := range test.expectedSecrets { + secret, ok := ts.secrets[qt] + if !ok { + t.Errorf("Test %d required secrets '%v', but got '%v'.", i, test.expectedSecrets, ts.secrets) + break + } + if secret != ts.secrets[qt] { + t.Errorf("Test %d required secrets '%v', but got '%v'.", i, test.expectedSecrets, ts.secrets) + break + } + } + } +} + +func TestParseKeyFile(t *testing.T) { + var reader = strings.NewReader(`key "foo" { + algorithm hmac-sha256; + secret "36eowrtmxceNA3T5AdE+JNUOWFCw3amtcyHACnrDVgQ="; +}; +key "bar" { + algorithm hmac-sha256; + secret "X28hl0BOfAL5G0jsmJWSacrwn7YRm2f6U5brnzwWEus="; +}; +key "baz" { + secret "BycDPXSx/5YCD44Q4g5Nd2QNxNRDKwWTXddrU/zpIQM="; +};`) + + secrets, err := parseKeyFile(reader) + if err != nil { + t.Fatalf("Unexpected error: %q", err) + } + expectedSecrets := map[string]string{ + "foo.": "36eowrtmxceNA3T5AdE+JNUOWFCw3amtcyHACnrDVgQ=", + "bar.": "X28hl0BOfAL5G0jsmJWSacrwn7YRm2f6U5brnzwWEus=", + "baz.": "BycDPXSx/5YCD44Q4g5Nd2QNxNRDKwWTXddrU/zpIQM=", + } + + if len(secrets) != len(expectedSecrets) { + t.Fatalf("result has %d keys. expected %d", len(secrets), len(expectedSecrets)) + } + + for k, sec := range secrets { + expectedSec, ok := expectedSecrets[k] + if !ok { + t.Errorf("unexpected key in result. %q", k) + continue + } + if sec != expectedSec { + t.Errorf("incorrect secret in result for key %q. 
expected %q got %q ", k, expectedSec, sec) + } + } +} + +func TestParseKeyFileErrors(t *testing.T) { + tests := []struct { + in string + err string + }{ + {in: `key {`, err: "expected key name \"key {\""}, + {in: `foo "key" {`, err: "unexpected token \"foo\""}, + { + in: `key "foo" { + secret "36eowrtmxceNA3T5AdE+JNUOWFCw3amtcyHACnrDVgQ="; + }; + key "foo" { + secret "X28hl0BOfAL5G0jsmJWSacrwn7YRm2f6U5brnzwWEus="; + }; `, + err: "key \"foo.\" redefined", + }, + {in: `key "foo" { + schmalgorithm hmac-sha256;`, + err: "unexpected token \"schmalgorithm\"", + }, + { + in: `key "foo" { + schmecret "36eowrtmxceNA3T5AdE+JNUOWFCw3amtcyHACnrDVgQ=";`, + err: "unexpected token \"schmecret\"", + }, + { + in: `key "foo" { + secret`, + err: "expected secret key \"\\tsecret\"", + }, + { + in: `key "foo" { + secret ;`, + err: "expected secret key \"\\tsecret ;\"", + }, + { + in: `key "foo" { + };`, + err: "expected secret for key \"foo.\"", + }, + } + for i, testcase := range tests { + _, err := parseKeyFile(strings.NewReader(testcase.in)) + if err == nil { + t.Errorf("Test %d: expected error, got no error", i) + continue + } + if err.Error() != testcase.err { + t.Errorf("Test %d: Expected error: %q, got %q", i, testcase.err, err.Error()) + } + } +} diff --git a/plugin/tsig/tsig.go b/plugin/tsig/tsig.go new file mode 100644 index 0000000..6441c8a --- /dev/null +++ b/plugin/tsig/tsig.go @@ -0,0 +1,140 @@ +package tsig + +import ( + "context" + "encoding/binary" + "encoding/hex" + "time" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// TSIGServer verifies tsig status and adds tsig to responses +type TSIGServer struct { + Zones []string + secrets map[string]string // [key-name]secret + types qTypes + all bool + Next plugin.Handler +} + +type qTypes map[uint16]struct{} + +// Name implements plugin.Handler +func (t TSIGServer) Name() string { return pluginName } + +// ServeDNS implements plugin.Handler +func (t *TSIGServer) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + var err error + state := request.Request{Req: r, W: w} + if z := plugin.Zones(t.Zones).Matches(state.Name()); z == "" { + return plugin.NextOrFailure(t.Name(), t.Next, ctx, w, r) + } + + var tsigRR = r.IsTsig() + rcode := dns.RcodeSuccess + if !t.tsigRequired(state.QType()) && tsigRR == nil { + return plugin.NextOrFailure(t.Name(), t.Next, ctx, w, r) + } + + if tsigRR == nil { + log.Debugf("rejecting '%s' request without TSIG\n", dns.TypeToString[state.QType()]) + rcode = dns.RcodeRefused + } + + // wrap the response writer so the response will be TSIG signed. + w = &restoreTsigWriter{w, r, tsigRR} + + tsigStatus := w.TsigStatus() + if tsigStatus != nil { + log.Debugf("TSIG validation failed: %v %v", dns.TypeToString[state.QType()], tsigStatus) + rcode = dns.RcodeNotAuth + switch tsigStatus { + case dns.ErrSecret: + tsigRR.Error = dns.RcodeBadKey + case dns.ErrTime: + tsigRR.Error = dns.RcodeBadTime + default: + tsigRR.Error = dns.RcodeBadSig + } + resp := new(dns.Msg).SetRcode(r, rcode) + w.WriteMsg(resp) + return dns.RcodeSuccess, nil + } + + // strip the TSIG RR. Next, and subsequent plugins will not see the TSIG RRs. + // This violates forwarding cases (RFC 8945 5.5). 
See README.md Bugs + if len(r.Extra) > 1 { + r.Extra = r.Extra[0 : len(r.Extra)-1] + } else { + r.Extra = []dns.RR{} + } + + if rcode == dns.RcodeSuccess { + rcode, err = plugin.NextOrFailure(t.Name(), t.Next, ctx, w, r) + if err != nil { + log.Errorf("request handler returned an error: %v\n", err) + } + } + // If the plugin chain result was not an error, restore the TSIG and write the response. + if !plugin.ClientWrite(rcode) { + resp := new(dns.Msg).SetRcode(r, rcode) + w.WriteMsg(resp) + } + return dns.RcodeSuccess, nil +} + +func (t *TSIGServer) tsigRequired(qtype uint16) bool { + if t.all { + return true + } + if _, ok := t.types[qtype]; ok { + return true + } + return false +} + +// restoreTsigWriter Implement Response Writer, and adds a TSIG RR to a response +type restoreTsigWriter struct { + dns.ResponseWriter + req *dns.Msg // original request excluding TSIG if it has one + reqTSIG *dns.TSIG // original TSIG +} + +// WriteMsg adds a TSIG RR to the response +func (r *restoreTsigWriter) WriteMsg(m *dns.Msg) error { + // Make sure the response has an EDNS OPT RR if the request had it. + // Otherwise ScrubWriter would append it *after* TSIG, making it a non-compliant DNS message. + state := request.Request{Req: r.req, W: r.ResponseWriter} + state.SizeAndDo(m) + + repTSIG := m.IsTsig() + if r.reqTSIG != nil && repTSIG == nil { + repTSIG = new(dns.TSIG) + repTSIG.Hdr = dns.RR_Header{Name: r.reqTSIG.Hdr.Name, Rrtype: dns.TypeTSIG, Class: dns.ClassANY} + repTSIG.Algorithm = r.reqTSIG.Algorithm + repTSIG.OrigId = m.MsgHdr.Id + repTSIG.Error = r.reqTSIG.Error + repTSIG.MAC = r.reqTSIG.MAC + repTSIG.MACSize = r.reqTSIG.MACSize + if repTSIG.Error == dns.RcodeBadTime { + // per RFC 8945 5.2.3. client time goes into TimeSigned, server time in OtherData, OtherLen = 6 ... + repTSIG.TimeSigned = r.reqTSIG.TimeSigned + b := make([]byte, 8) + // TimeSigned is network byte order. 
+ binary.BigEndian.PutUint64(b, uint64(time.Now().Unix())) + // truncate to 48 least significant bits (network order 6 rightmost bytes) + repTSIG.OtherData = hex.EncodeToString(b[2:]) + repTSIG.OtherLen = 6 + } + m.Extra = append(m.Extra, repTSIG) + } + + return r.ResponseWriter.WriteMsg(m) +} + +const pluginName = "tsig" diff --git a/plugin/tsig/tsig_test.go b/plugin/tsig/tsig_test.go new file mode 100644 index 0000000..f7ec1fd --- /dev/null +++ b/plugin/tsig/tsig_test.go @@ -0,0 +1,255 @@ +package tsig + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +func TestServeDNS(t *testing.T) { + cases := []struct { + zones []string + reqTypes qTypes + qType uint16 + qTsig, all bool + expectRcode int + expectTsig bool + statusError bool + }{ + { + zones: []string{"."}, + all: true, + qType: dns.TypeA, + qTsig: true, + expectRcode: dns.RcodeSuccess, + expectTsig: true, + }, + { + zones: []string{"."}, + all: true, + qType: dns.TypeA, + qTsig: false, + expectRcode: dns.RcodeRefused, + expectTsig: false, + }, + { + zones: []string{"another.domain."}, + all: true, + qType: dns.TypeA, + qTsig: false, + expectRcode: dns.RcodeSuccess, + expectTsig: false, + }, + { + zones: []string{"another.domain."}, + all: true, + qType: dns.TypeA, + qTsig: true, + expectRcode: dns.RcodeSuccess, + expectTsig: false, + }, + { + zones: []string{"."}, + reqTypes: qTypes{dns.TypeAXFR: {}}, + qType: dns.TypeAXFR, + qTsig: true, + expectRcode: dns.RcodeSuccess, + expectTsig: true, + }, + { + zones: []string{"."}, + reqTypes: qTypes{}, + qType: dns.TypeA, + qTsig: false, + expectRcode: dns.RcodeSuccess, + expectTsig: false, + }, + { + zones: []string{"."}, + reqTypes: qTypes{}, + qType: dns.TypeA, + qTsig: true, + expectRcode: dns.RcodeSuccess, + expectTsig: true, + }, + { + zones: []string{"."}, + all: true, + qType: dns.TypeA, + qTsig: true, + expectRcode: dns.RcodeNotAuth, + expectTsig: true, + statusError: true, + }, + } + + for i, tc := range cases { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + tsig := TSIGServer{ + Zones: tc.zones, + all: tc.all, + types: tc.reqTypes, + Next: testHandler(), + } + + ctx := context.TODO() + + var w *dnstest.Recorder + if tc.statusError { + w = dnstest.NewRecorder(&ErrWriter{err: dns.ErrSig}) + } else { + w = dnstest.NewRecorder(&test.ResponseWriter{}) + } + r := new(dns.Msg) + r.SetQuestion("test.example.", tc.qType) + if tc.qTsig { + r.SetTsig("test.key.", dns.HmacSHA256, 300, time.Now().Unix()) + } + + _, err := tsig.ServeDNS(ctx, w, r) + if err != nil { + t.Fatal(err) + } + + if w.Msg.Rcode != tc.expectRcode { + t.Fatalf("expected rcode %v, got %v", tc.expectRcode, w.Msg.Rcode) + } + + if ts := w.Msg.IsTsig(); ts == nil && tc.expectTsig { + t.Fatal("expected TSIG in response") + } + if ts := w.Msg.IsTsig(); ts != nil && !tc.expectTsig { + t.Fatal("expected no TSIG in response") + } + }) + } +} + +func TestServeDNSTsigErrors(t *testing.T) { + clientNow := time.Now().Unix() + + cases := []struct { + desc string + tsigErr error + expectRcode int + expectError int + expectOtherLength int + expectTimeSigned int64 + }{ + { + desc: "Unknown Key", + tsigErr: dns.ErrSecret, + expectRcode: dns.RcodeNotAuth, + expectError: dns.RcodeBadKey, + expectOtherLength: 0, + expectTimeSigned: 0, + }, + { + desc: "Bad Signature", + tsigErr: dns.ErrSig, + expectRcode: dns.RcodeNotAuth, + expectError: dns.RcodeBadSig, + 
expectOtherLength: 0, + expectTimeSigned: 0, + }, + { + desc: "Bad Time", + tsigErr: dns.ErrTime, + expectRcode: dns.RcodeNotAuth, + expectError: dns.RcodeBadTime, + expectOtherLength: 6, + expectTimeSigned: clientNow, + }, + } + + tsig := TSIGServer{ + Zones: []string{"."}, + all: true, + Next: testHandler(), + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + ctx := context.TODO() + + var w *dnstest.Recorder + + w = dnstest.NewRecorder(&ErrWriter{err: tc.tsigErr}) + + r := new(dns.Msg) + r.SetQuestion("test.example.", dns.TypeA) + r.SetTsig("test.key.", dns.HmacSHA256, 300, clientNow) + + // set a fake MAC and Size in request + rtsig := r.IsTsig() + rtsig.MAC = "0123456789012345678901234567890101234567890123456789012345678901" + rtsig.MACSize = 32 + + _, err := tsig.ServeDNS(ctx, w, r) + if err != nil { + t.Fatal(err) + } + + if w.Msg.Rcode != tc.expectRcode { + t.Fatalf("expected rcode %v, got %v", tc.expectRcode, w.Msg.Rcode) + } + + ts := w.Msg.IsTsig() + + if ts == nil { + t.Fatal("expected TSIG in response") + } + + if int(ts.Error) != tc.expectError { + t.Errorf("expected TSIG error code %v, got %v", tc.expectError, ts.Error) + } + + if len(ts.OtherData)/2 != tc.expectOtherLength { + t.Errorf("expected Other of length %v, got %v", tc.expectOtherLength, len(ts.OtherData)) + } + + if int(ts.OtherLen) != tc.expectOtherLength { + t.Errorf("expected OtherLen %v, got %v", tc.expectOtherLength, ts.OtherLen) + } + + if ts.TimeSigned != uint64(tc.expectTimeSigned) { + t.Errorf("expected TimeSigned to be %v, got %v", tc.expectTimeSigned, ts.TimeSigned) + } + }) + } +} + +func testHandler() test.HandlerFunc { + return func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + state := request.Request{W: w, Req: r} + qname := state.Name() + m := new(dns.Msg) + rcode := dns.RcodeServerFailure + if qname == "test.example." { + m.SetReply(r) + rr := test.A("test.example. 300 IN A 1.2.3.48") + m.Answer = []dns.RR{rr} + m.Authoritative = true + rcode = dns.RcodeSuccess + } + m.SetRcode(r, rcode) + w.WriteMsg(m) + return rcode, nil + } +} + +// a test.ResponseWriter that always returns err as the TSIG status error +type ErrWriter struct { + err error + test.ResponseWriter +} + +// TsigStatus always returns an error. +func (t *ErrWriter) TsigStatus() error { return t.err } diff --git a/plugin/v64dns/README.md b/plugin/v64dns/README.md new file mode 100644 index 0000000..6868d0a --- /dev/null +++ b/plugin/v64dns/README.md @@ -0,0 +1,28 @@ +## 语法格式 +``` +v64dns 要负责解析的域名 { +v4ns <v4NS的子域名> <v4NS的ip地址> +v6ns <v6NS的子域名> <v6NS的ip地址> +chain <v4-only子域名称> <v6-only子域名称> +ip-embed 是否将请求ip嵌入到响应记录中 +analyze <图数据库类型(目前仅支持neo4j)> <图数据库地址:端口(不指定则采用默认)> <用户名> <密码> +} +``` +v64DNS将负责解析所给域名下的v4-only域名和v6-only域名, +最终返回结果为解析器ip链组成的TXT记录,格式为ip1 ip2 ip3 ... +需要搭配探针使用,探针格式为 +`[进度标识].[随机数].[实验ID].[水印].[子域名].[实验域名]` +v64DNS默认还集成了分析脚本,添加analyze参数即可 + +## example + +``` +v64ns example.com { +v4ns ns4 1.1.1.1 +v6ns ns6 1:1::1 +chain v4 v6 +ip-embed +analyze neo4j 1.1.1.1:7474 test test123 +} + +```
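+For illustration only (every concrete value below is invented), a probe for the example configuration above could take the shape `c1.7312.rip01.<watermark>.v4.example.com`, where `c1` is the progress marker, `7312` the random number, `rip01` the experiment ID and `v4` the chain entry point; the final answer is a TXT record listing the resolver IPs observed along the chain in the `ip1 ip2 ip3 ...` form described above.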
\ No newline at end of file diff --git a/plugin/v64dns/analyze/analyze.go b/plugin/v64dns/analyze/analyze.go new file mode 100644 index 0000000..514f2ac --- /dev/null +++ b/plugin/v64dns/analyze/analyze.go @@ -0,0 +1,101 @@ +package analyze + +import ( + "context" + "ohmydns2/plugin/v64dns/analyze/pb" + "sync" + "sync/atomic" + "unsafe" + + olog "ohmydns2/plugin/pkg/log" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +type Analyzer struct { + Graphtype string + Graphuri string + GraphUser string + GraphPass string + IsA bool + grpcClient pb.GrpcServiceClient +} + +var ( + A Analyzer + c unsafe.Pointer + lck sync.Mutex +) + +func NewAnalyzer(gtype, guri, guser, gpass string) (*Analyzer, error) { + B := Analyzer{ + Graphtype: gtype, + Graphuri: guri, + GraphUser: guser, + GraphPass: gpass, + IsA: true, + grpcClient: nil, + } + conn, err := B.GetConn("127.0.0.1:56789") + if err != nil { + olog.Errorf("GetGrpcClient:%s", err.Error()) + panic("error") + } + B.grpcClient = pb.NewGrpcServiceClient(conn) + return &B, nil +} + +func (a Analyzer) GetConn(target string) (*grpc.ClientConn, error) { + if atomic.LoadPointer(&c) != nil { + return (*grpc.ClientConn)(c), nil + } + lck.Lock() + defer lck.Unlock() + if atomic.LoadPointer(&c) != nil { //double check + return (*grpc.ClientConn)(c), nil + } + cli, err := newGrpcConn(target) + if err != nil { + return nil, err + } + atomic.StorePointer(&c, unsafe.Pointer(cli)) + return cli, nil +} + +// 新建grpc连接 +func newGrpcConn(target string) (*grpc.ClientConn, error) { + conn, err := grpc.Dial( + target, + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + olog.Errorf("连接服务端失败: %s", err) + return nil, err + } + return conn, nil +} + +// 远程调用Python脚本方法 +func (a Analyzer) Go2py(i []string) { + lck.Lock() + defer lck.Unlock() + // 调用服务端函数 + r, err := a.grpcClient.AnalyzeService(context.Background(), &pb.DnsChain{ + Gtype: a.Graphtype, + Guri: a.Graphuri, + Guser: a.GraphUser, + Gpass: a.GraphPass, + Data: i, + }) + if err != nil { + print(r) + olog.Errorf("调用解析链分析器代码失败: %s", err) + return + } + // 处理不成功则显示警告信息 + if r.Res != "success" { + olog.Warning(r.Res) + } + +} diff --git a/plugin/v64dns/analyze/analyze_test.go b/plugin/v64dns/analyze/analyze_test.go new file mode 100644 index 0000000..3e69627 --- /dev/null +++ b/plugin/v64dns/analyze/analyze_test.go @@ -0,0 +1,42 @@ +package analyze + +import ( + "ohmydns2/plugin/v64dns/analyze/pb" + "testing" +) + +func TestAnalyzer_go2py(t *testing.T) { + type fields struct { + Graphtype string + Graphuri string + GraphUser string + GraphPass string + IsA bool + grpcClient pb.GrpcServiceClient + } + type args struct { + i []string + } + tests := []struct { + name string + fields fields + args args + }{ + {"test", + fields{ + Graphtype: "neo4j", + Graphuri: "neo4j+s://600c8c12.databases.neo4j.io:7687", + GraphUser: "neo4j", + GraphPass: "5PvtszWJ7ru4hyBTKO9mGnavOJETTGhJYQIeC5SzmPQ", + IsA: true, + }, + args{i: []string{"1.1.1.1", "2001::1"}}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + NewAnalyzer(tt.fields.Graphtype, tt.fields.Graphuri, tt.fields.GraphUser, tt.fields.GraphPass) + A.Go2py(tt.args.i) + }) + } +} diff --git a/plugin/v64dns/analyze/pb/analyze.pb.go b/plugin/v64dns/analyze/pb/analyze.pb.go new file mode 100644 index 0000000..dea9a30 --- /dev/null +++ b/plugin/v64dns/analyze/pb/analyze.pb.go @@ -0,0 +1,251 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. 
+// versions: +// protoc-gen-go v1.28.1 +// protoc v3.12.4 +// source: analyze.proto + +package pb + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type DnsChain struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Gtype string `protobuf:"bytes,1,opt,name=gtype,proto3" json:"gtype,omitempty"` + Guri string `protobuf:"bytes,2,opt,name=guri,proto3" json:"guri,omitempty"` + Guser string `protobuf:"bytes,3,opt,name=guser,proto3" json:"guser,omitempty"` + Gpass string `protobuf:"bytes,4,opt,name=gpass,proto3" json:"gpass,omitempty"` + Data []string `protobuf:"bytes,5,rep,name=data,proto3" json:"data,omitempty"` +} + +func (x *DnsChain) Reset() { + *x = DnsChain{} + if protoimpl.UnsafeEnabled { + mi := &file_analyze_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DnsChain) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DnsChain) ProtoMessage() {} + +func (x *DnsChain) ProtoReflect() protoreflect.Message { + mi := &file_analyze_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DnsChain.ProtoReflect.Descriptor instead. +func (*DnsChain) Descriptor() ([]byte, []int) { + return file_analyze_proto_rawDescGZIP(), []int{0} +} + +func (x *DnsChain) GetGtype() string { + if x != nil { + return x.Gtype + } + return "" +} + +func (x *DnsChain) GetGuri() string { + if x != nil { + return x.Guri + } + return "" +} + +func (x *DnsChain) GetGuser() string { + if x != nil { + return x.Guser + } + return "" +} + +func (x *DnsChain) GetGpass() string { + if x != nil { + return x.Gpass + } + return "" +} + +func (x *DnsChain) GetData() []string { + if x != nil { + return x.Data + } + return nil +} + +type Result struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Res string `protobuf:"bytes,1,opt,name=res,proto3" json:"res,omitempty"` +} + +func (x *Result) Reset() { + *x = Result{} + if protoimpl.UnsafeEnabled { + mi := &file_analyze_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Result) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Result) ProtoMessage() {} + +func (x *Result) ProtoReflect() protoreflect.Message { + mi := &file_analyze_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Result.ProtoReflect.Descriptor instead. 
+func (*Result) Descriptor() ([]byte, []int) { + return file_analyze_proto_rawDescGZIP(), []int{1} +} + +func (x *Result) GetRes() string { + if x != nil { + return x.Res + } + return "" +} + +var File_analyze_proto protoreflect.FileDescriptor + +var file_analyze_proto_rawDesc = []byte{ + 0x0a, 0x0d, 0x61, 0x6e, 0x61, 0x6c, 0x79, 0x7a, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, + 0x1e, 0x6f, 0x68, 0x6d, 0x79, 0x64, 0x6e, 0x73, 0x32, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, + 0x2e, 0x76, 0x36, 0x34, 0x44, 0x4e, 0x53, 0x2e, 0x61, 0x6e, 0x61, 0x6c, 0x79, 0x7a, 0x65, 0x22, + 0x74, 0x0a, 0x08, 0x44, 0x6e, 0x73, 0x43, 0x68, 0x61, 0x69, 0x6e, 0x12, 0x14, 0x0a, 0x05, 0x67, + 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x67, 0x74, 0x79, 0x70, + 0x65, 0x12, 0x12, 0x0a, 0x04, 0x67, 0x75, 0x72, 0x69, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x04, 0x67, 0x75, 0x72, 0x69, 0x12, 0x14, 0x0a, 0x05, 0x67, 0x75, 0x73, 0x65, 0x72, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x67, 0x75, 0x73, 0x65, 0x72, 0x12, 0x14, 0x0a, 0x05, 0x67, + 0x70, 0x61, 0x73, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x67, 0x70, 0x61, 0x73, + 0x73, 0x12, 0x12, 0x0a, 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, + 0x04, 0x64, 0x61, 0x74, 0x61, 0x22, 0x1a, 0x0a, 0x06, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, + 0x10, 0x0a, 0x03, 0x72, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x72, 0x65, + 0x73, 0x32, 0x71, 0x0a, 0x0b, 0x47, 0x72, 0x70, 0x63, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, + 0x12, 0x62, 0x0a, 0x0e, 0x41, 0x6e, 0x61, 0x6c, 0x79, 0x7a, 0x65, 0x53, 0x65, 0x72, 0x76, 0x69, + 0x63, 0x65, 0x12, 0x28, 0x2e, 0x6f, 0x68, 0x6d, 0x79, 0x64, 0x6e, 0x73, 0x32, 0x2e, 0x70, 0x6c, + 0x75, 0x67, 0x69, 0x6e, 0x2e, 0x76, 0x36, 0x34, 0x44, 0x4e, 0x53, 0x2e, 0x61, 0x6e, 0x61, 0x6c, + 0x79, 0x7a, 0x65, 0x2e, 0x44, 0x6e, 0x73, 0x43, 0x68, 0x61, 0x69, 0x6e, 0x1a, 0x26, 0x2e, 0x6f, + 0x68, 0x6d, 0x79, 0x64, 0x6e, 0x73, 0x32, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x2e, 0x76, + 0x36, 0x34, 0x44, 0x4e, 0x53, 0x2e, 0x61, 0x6e, 0x61, 0x6c, 0x79, 0x7a, 0x65, 0x2e, 0x72, 0x65, + 0x73, 0x75, 0x6c, 0x74, 0x42, 0x06, 0x5a, 0x04, 0x2e, 0x3b, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_analyze_proto_rawDescOnce sync.Once + file_analyze_proto_rawDescData = file_analyze_proto_rawDesc +) + +func file_analyze_proto_rawDescGZIP() []byte { + file_analyze_proto_rawDescOnce.Do(func() { + file_analyze_proto_rawDescData = protoimpl.X.CompressGZIP(file_analyze_proto_rawDescData) + }) + return file_analyze_proto_rawDescData +} + +var file_analyze_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_analyze_proto_goTypes = []interface{}{ + (*DnsChain)(nil), // 0: ohmydns2.plugin.v64DNS.analyze.DnsChain + (*Result)(nil), // 1: ohmydns2.plugin.v64DNS.analyze.result +} +var file_analyze_proto_depIdxs = []int32{ + 0, // 0: ohmydns2.plugin.v64DNS.analyze.GrpcService.AnalyzeService:input_type -> ohmydns2.plugin.v64DNS.analyze.DnsChain + 1, // 1: ohmydns2.plugin.v64DNS.analyze.GrpcService.AnalyzeService:output_type -> ohmydns2.plugin.v64DNS.analyze.result + 1, // [1:2] is the sub-list for method output_type + 0, // [0:1] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_analyze_proto_init() } +func file_analyze_proto_init() { + if File_analyze_proto != nil { + return + 
} + if !protoimpl.UnsafeEnabled { + file_analyze_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DnsChain); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_analyze_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Result); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_analyze_proto_rawDesc, + NumEnums: 0, + NumMessages: 2, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_analyze_proto_goTypes, + DependencyIndexes: file_analyze_proto_depIdxs, + MessageInfos: file_analyze_proto_msgTypes, + }.Build() + File_analyze_proto = out.File + file_analyze_proto_rawDesc = nil + file_analyze_proto_goTypes = nil + file_analyze_proto_depIdxs = nil +} diff --git a/plugin/v64dns/analyze/pb/analyze.proto b/plugin/v64dns/analyze/pb/analyze.proto new file mode 100644 index 0000000..fffbeea --- /dev/null +++ b/plugin/v64dns/analyze/pb/analyze.proto @@ -0,0 +1,20 @@ +syntax="proto3"; + +package ohmydns2.plugin.v64DNS.analyze; +option go_package=".;pb"; + +message DnsChain { + string gtype=1; + string guri=2; + string guser=3; + string gpass=4; + repeated string data=5; +} + +message result { + string res=1; +} + +service GrpcService { + rpc AnalyzeService(DnsChain) returns (result); +}
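The server side of this contract is implemented by the Python script `analyzer.py` further below. As a sketch only, and not part of this commit, a Go replacement built on the generated `pb` bindings could look like the following; the listen address mirrors the `127.0.0.1:56789` that `analyze.go` dials, while the type and log messages are illustrative:

```
package main

import (
	"context"
	"log"
	"net"

	"ohmydns2/plugin/v64dns/analyze/pb"

	"google.golang.org/grpc"
)

// echoAnalyzer is a stand-in for the Python analyzer: instead of writing the
// resolver chain into a graph database it only logs what it received.
type echoAnalyzer struct {
	pb.UnimplementedGrpcServiceServer
}

func (e *echoAnalyzer) AnalyzeService(ctx context.Context, in *pb.DnsChain) (*pb.Result, error) {
	log.Printf("graph=%s uri=%s chain=%v", in.Gtype, in.Guri, in.Data)
	// "success" is the value the Go caller (Go2py) checks for.
	return &pb.Result{Res: "success"}, nil
}

func main() {
	lis, err := net.Listen("tcp", "127.0.0.1:56789") // address hard-coded in analyze.go
	if err != nil {
		log.Fatal(err)
	}
	s := grpc.NewServer()
	pb.RegisterGrpcServiceServer(s, &echoAnalyzer{})
	log.Fatal(s.Serve(lis))
}
```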
\ No newline at end of file diff --git a/plugin/v64dns/analyze/pb/analyze_grpc.pb.go b/plugin/v64dns/analyze/pb/analyze_grpc.pb.go new file mode 100644 index 0000000..10edea8 --- /dev/null +++ b/plugin/v64dns/analyze/pb/analyze_grpc.pb.go @@ -0,0 +1,105 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.2.0 +// - protoc v3.12.4 +// source: analyze.proto + +package pb + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.32.0 or later. +const _ = grpc.SupportPackageIsVersion7 + +// GrpcServiceClient is the client API for GrpcService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type GrpcServiceClient interface { + AnalyzeService(ctx context.Context, in *DnsChain, opts ...grpc.CallOption) (*Result, error) +} + +type grpcServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewGrpcServiceClient(cc grpc.ClientConnInterface) GrpcServiceClient { + return &grpcServiceClient{cc} +} + +func (c *grpcServiceClient) AnalyzeService(ctx context.Context, in *DnsChain, opts ...grpc.CallOption) (*Result, error) { + out := new(Result) + err := c.cc.Invoke(ctx, "/ohmydns2.plugin.v64DNS.analyze.GrpcService/AnalyzeService", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// GrpcServiceServer is the server API for GrpcService service. +// All implementations must embed UnimplementedGrpcServiceServer +// for forward compatibility +type GrpcServiceServer interface { + AnalyzeService(context.Context, *DnsChain) (*Result, error) + mustEmbedUnimplementedGrpcServiceServer() +} + +// UnimplementedGrpcServiceServer must be embedded to have forward compatible implementations. +type UnimplementedGrpcServiceServer struct { +} + +func (UnimplementedGrpcServiceServer) AnalyzeService(context.Context, *DnsChain) (*Result, error) { + return nil, status.Errorf(codes.Unimplemented, "method AnalyzeService not implemented") +} +func (UnimplementedGrpcServiceServer) mustEmbedUnimplementedGrpcServiceServer() {} + +// UnsafeGrpcServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to GrpcServiceServer will +// result in compilation errors. 
+type UnsafeGrpcServiceServer interface { + mustEmbedUnimplementedGrpcServiceServer() +} + +func RegisterGrpcServiceServer(s grpc.ServiceRegistrar, srv GrpcServiceServer) { + s.RegisterService(&GrpcService_ServiceDesc, srv) +} + +func _GrpcService_AnalyzeService_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(DnsChain) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(GrpcServiceServer).AnalyzeService(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/ohmydns2.plugin.v64DNS.analyze.GrpcService/AnalyzeService", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(GrpcServiceServer).AnalyzeService(ctx, req.(*DnsChain)) + } + return interceptor(ctx, in, info, handler) +} + +// GrpcService_ServiceDesc is the grpc.ServiceDesc for GrpcService service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var GrpcService_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "ohmydns2.plugin.v64DNS.analyze.GrpcService", + HandlerType: (*GrpcServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "AnalyzeService", + Handler: _GrpcService_AnalyzeService_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "analyze.proto", +} diff --git a/plugin/v64dns/analyze/pb/analyze_pb2.py b/plugin/v64dns/analyze/pb/analyze_pb2.py new file mode 100644 index 0000000..81a6b0c --- /dev/null +++ b/plugin/v64dns/analyze/pb/analyze_pb2.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: analyze.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ranalyze.proto\x12\x1eohmydns2.plugin.v64DNS.analyze\"S\n\x08\x44nsChain\x12\r\n\x05gtype\x18\x01 \x01(\t\x12\x0c\n\x04guri\x18\x02 \x01(\t\x12\r\n\x05guser\x18\x03 \x01(\t\x12\r\n\x05gpass\x18\x04 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x05 \x03(\t\"\x15\n\x06result\x12\x0b\n\x03res\x18\x01 \x01(\t2q\n\x0bGrpcService\x12\x62\n\x0e\x41nalyzeService\x12(.ohmydns2.plugin.v64DNS.analyze.DnsChain\x1a&.ohmydns2.plugin.v64DNS.analyze.resultB\x06Z\x04.;pbb\x06proto3') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'analyze_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'Z\004.;pb' + _DNSCHAIN._serialized_start=49 + _DNSCHAIN._serialized_end=132 + _RESULT._serialized_start=134 + _RESULT._serialized_end=155 + _GRPCSERVICE._serialized_start=157 + _GRPCSERVICE._serialized_end=270 +# @@protoc_insertion_point(module_scope) diff --git a/plugin/v64dns/analyze/pb/analyze_pb2.pyi b/plugin/v64dns/analyze/pb/analyze_pb2.pyi new file mode 100644 index 0000000..a03531d --- /dev/null +++ b/plugin/v64dns/analyze/pb/analyze_pb2.pyi @@ -0,0 +1,26 @@ +from google.protobuf.internal import containers as _containers +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message 
+from typing import ClassVar as _ClassVar, Iterable as _Iterable, Optional as _Optional + +DESCRIPTOR: _descriptor.FileDescriptor + +class DnsChain(_message.Message): + __slots__ = ["data", "gpass", "gtype", "guri", "guser"] + DATA_FIELD_NUMBER: _ClassVar[int] + GPASS_FIELD_NUMBER: _ClassVar[int] + GTYPE_FIELD_NUMBER: _ClassVar[int] + GURI_FIELD_NUMBER: _ClassVar[int] + GUSER_FIELD_NUMBER: _ClassVar[int] + data: _containers.RepeatedScalarFieldContainer[str] + gpass: str + gtype: str + guri: str + guser: str + def __init__(self, gtype: _Optional[str] = ..., guri: _Optional[str] = ..., guser: _Optional[str] = ..., gpass: _Optional[str] = ..., data: _Optional[_Iterable[str]] = ...) -> None: ... + +class result(_message.Message): + __slots__ = ["res"] + RES_FIELD_NUMBER: _ClassVar[int] + res: str + def __init__(self, res: _Optional[str] = ...) -> None: ... diff --git a/plugin/v64dns/analyze/pb/analyze_pb2_grpc.py b/plugin/v64dns/analyze/pb/analyze_pb2_grpc.py new file mode 100644 index 0000000..d98bc89 --- /dev/null +++ b/plugin/v64dns/analyze/pb/analyze_pb2_grpc.py @@ -0,0 +1,66 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +import analyze_pb2 as analyze__pb2 + + +class GrpcServiceStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.AnalyzeService = channel.unary_unary( + '/ohmydns2.plugin.v64DNS.analyze.GrpcService/AnalyzeService', + request_serializer=analyze__pb2.DnsChain.SerializeToString, + response_deserializer=analyze__pb2.result.FromString, + ) + + +class GrpcServiceServicer(object): + """Missing associated documentation comment in .proto file.""" + + def AnalyzeService(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_GrpcServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + 'AnalyzeService': grpc.unary_unary_rpc_method_handler( + servicer.AnalyzeService, + request_deserializer=analyze__pb2.DnsChain.FromString, + response_serializer=analyze__pb2.result.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'ohmydns2.plugin.v64DNS.analyze.GrpcService', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. 
+class GrpcService(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def AnalyzeService(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/ohmydns2.plugin.v64DNS.analyze.GrpcService/AnalyzeService', + analyze__pb2.DnsChain.SerializeToString, + analyze__pb2.result.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/plugin/v64dns/analyze/pb/analyzedutil.py b/plugin/v64dns/analyze/pb/analyzedutil.py new file mode 100644 index 0000000..071583e --- /dev/null +++ b/plugin/v64dns/analyze/pb/analyzedutil.py @@ -0,0 +1,69 @@ +# analyzeutl 是分析器工具集合 +import awdb + + +# 过滤所有空值 +def filterNull(s): + if s == "": + return "UnKnown" + return s + + +def IP46(IP: str): + if ':' in IP: + return "v6" + if '.' in IP: + return "v4" + return "Unknown" + + +path_ip4app = "./data/IP_scene_all_cn.awdb" +path_ip6 = "./data/IP_city_single_BD09_WGS84_ipv6_en.awdb" +path_ip4qvxian = "./data/IP_basic_single_WGS84_en.awdb" + + +# 实例化数据读取器 +def makereader(arg=0): + # 默认加载所有离线数据 + dloader_ip4app = awdb.open_database(path_ip4app) + dloader_ip6 = awdb.open_database(path_ip6) + dloader_ip4qx = awdb.open_database(path_ip4qvxian) + return dloader_ip4app, dloader_ip4qx, dloader_ip6 + + +reader_ip4app, reader_ip4qx, reader_ip6 = makereader() + + +# # 传入一个ip地址,根据对象中的数据进行IP关联分析 +# def make_IPinfo(ip): +# record=getrecord(ip) +# cou = record.get('areacode', b'').decode("utf-8") +# multiarea = record.get('multiAreas', {}) +# if multiarea: +# print("有多组记录") +# else: +# pass +# return record + +# 返回IP离线库中与ip相关的记录,记录中不包含应用场景 +def getrecord(ip): + if (IP46(ip) == "v4"): + return IP4_info(ip) + elif (IP46(ip) == "v6"): + return IP6_info(ip) + else: + print("地址存在问题") + print(ip) + return 1 + + +# 返回IPv4记录 +def IP4_info(ip): + (record, prefix_len) = reader_ip4qx.get_with_prefix_len(ip) + return record + + +# 返回IPv6记录 +def IP6_info(ip): + (record, prefix_len) = reader_ip6.get_with_prefix_len(ip) + return record diff --git a/plugin/v64dns/analyze/pb/analyzer.py b/plugin/v64dns/analyze/pb/analyzer.py new file mode 100644 index 0000000..ff21d6b --- /dev/null +++ b/plugin/v64dns/analyze/pb/analyzer.py @@ -0,0 +1,201 @@ +# !coding=utf-8 +import datetime +import logging +import time +from concurrent import futures + +import grpc +import pytz +from neomodel import db, StringProperty, DateTimeFormatProperty, RelationshipTo, StructuredRel, IntegerProperty, \ + StructuredNode, config, BooleanProperty + +import analyze_pb2 +import analyze_pb2_grpc +import analyzedutil as aul + +logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') + + +class node: + ip = "" + AS = "" + next = "" + isp = "" + cou = "" + couCode = "" + prov = "" + lat = "" + lng = "" + FindTime = "" + dataOK = "" + owner = "" + + def __init__(self, ip): + self.ip = ip + record = aul.getrecord(ip) + if record == 1: + self.dataOK = False + return + self.dataOK = True + self.isp = aul.filterNull(record.get('isp', b'').decode("utf-8")) + self.lat = aul.filterNull(record.get('latwgs', b'').decode("utf-8")) + self.lng = aul.filterNull(record.get('lngwgs', b'').decode("utf-8")) + self.prov = aul.filterNull(record.get('province', b'').decode("utf-8")) + self.AS = aul.filterNull(record.get('asnumber', 
b'').decode("utf-8")) + self.couCode = aul.filterNull(record.get('areacode', b'').decode("utf-8")) + self.cou = aul.filterNull(record.get('country', b'').decode("utf-8")) + # self.FindTime = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()) + self.FindTime = datetime.datetime.now(pytz.UTC) + self.owner = aul.filterNull(record.get('owner', b'').decode("utf-8")) + + +# 与go之间的通信 + +class RequestServe(analyze_pb2_grpc.GrpcServiceServicer): + graph_conn = "" + + def AnalyzeService(self, request, context): + ''' + 具体实现AnalyzeService服务方法 + :param request: + :param context: + :return: + ''' + r = request + print("receive R!!") + print(r) + if r.gtype == "neo4j": + if self.graph_conn == "": + url = str(r.guri).split("//")[0] + "//" + r.guser + ":" + r.gpass + "@" + str(r.guri).split("//")[1] + self.graph_conn = neo4j_connector(url) + print("已连接到图数据库Neo4j:" + r.guri) + print(self.graph_conn) + result = self.graph_conn.work_with_neoj_53(r.data) + return analyze_pb2.result(res=result) + return analyze_pb2.result(res="not support") + + +working_addr = "127.0.0.1" +working_port = "56789" + + +def serve(): + # 启动 rpc 服务,这里可定义最大接收和发送大小(单位M),默认只有4M + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10), options=[ + ('grpc.max_send_message_length', 100 * 1024 * 1024), + ('grpc.max_receive_message_length', 100 * 1024 * 1024)]) + + analyze_pb2_grpc.add_GrpcServiceServicer_to_server(RequestServe(), server) + server.add_insecure_port(working_addr + ":" + working_port) + server.start() + print("Python分析模块启动,工作在 " + working_addr + ":" + working_port) + try: + while True: + time.sleep(60 * 60 * 24) # one day in seconds + except KeyboardInterrupt: + server.stop(0) + + +class RelResolver53(StructuredRel): + W = IntegerProperty() + LTIME = DateTimeFormatProperty(default_now=True, format="%Y-%m-%d %H:%M:%S") + + +class NodeResolver53(StructuredNode): + IP = StringProperty(required=True, unique_index=True) + AS = StringProperty() + ISP = StringProperty() + COU = StringProperty() + CCODE = StringProperty() + PROV = StringProperty() + LAT = StringProperty() + LNG = StringProperty() + IPType = StringProperty() + FTIME = DateTimeFormatProperty(format="%Y-%m-%d %H:%M:%S") + LTIME = DateTimeFormatProperty(default_now=True, format="%Y-%m-%d %H:%M:%S") + W = IntegerProperty() + ISPUBLIC = BooleanProperty(default=False) + LINK = RelationshipTo("NodeResolver53", "IP_LINK", model=RelResolver53) + + +class neo4j_connector: + graph = "" + + # nodematcher = "" + # relatmatcher = "" + + def __init__(self, url): + # 连接neo4j + #config.ENCRYPTED = True + print(url) + config.DATABASE_URL =url + db.set_connection(url) + # self.graph = Graph(guri, auth=(guser, gpass), name="neo4j") + # self.nodematcher = NodeMatcher(self.graph) + # self.relatmatcher = RelationshipMatcher(self.graph) + + def work_with_neoj_53(self, data): + for d in range(len(data) - 1): + n = node(data[d]) + if not n.dataOK: + return "node err because ip" + # 查询是否存在节点 + nd, exist = self.checknode_neo4j(n.ip) + # 不存在则新建 + if not exist: + nd = NodeResolver53(AS=n.AS, COU=n.cou, + CCODE=n.couCode, LAT=n.lat, LNG=n.lng, + ISP=n.isp, IPType=aul.IP46(n.ip), PROV=n.prov, FTIME=n.FindTime, + LTIME=n.FindTime, IP=n.ip, W=1) + if data[2] == "0" and d == 0: + nd.ISPUBLIC = True + nd.save() + # 存在则只修改时间 + else: + # nd.LTIME = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()) + nd.LTIME = datetime.datetime.now(pytz.UTC) + if nd.W is not None: + nd.W += 1 + else: + nd.W = 1 + nd.save() + + # 查询是否存在关系 + L, lexist = self.checklink_neo4j(data[0], data[1]) + # 数据存在问题则退出 + 
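+            # checklink_neo4j returns ("Err", False) when either endpoint node is
+            # missing, (rel, True) when the link already exists, and ([from, to],
+            # False) when both nodes exist but are not yet connected.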
if L == "Err": + return "node err when link" + # 不存在则建立关联 + if not lexist: + L[0].LINK.connect(L[1], {'W': 1, 'LTIME': datetime.datetime.now(pytz.UTC)}).save() + # relates.append( + # Relationship(nodes[i], 'IP_link', nodes[i + 1], TIME=time.time(), LTIME=time.time(), W=1)) + # 存在则修改权重 + else: + L.W += 1 + # L.LTIME = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()) + L.LTIME = datetime.datetime.now(pytz.UTC) + L.save() + # 提交链接 + return "success" + + def checknode_neo4j(self, ip): + a = NodeResolver53.nodes.get_or_none(IP=ip) + if a is not None: + return a, True + return None, False + + def checklink_neo4j(self, ip_from, ip_to): + f = NodeResolver53.nodes.get_or_none(IP=ip_from) + t = NodeResolver53.nodes.get_or_none(IP=ip_to) + if f is None or t is None: + return "Err", False + rel = f.LINK.relationship(t) + if rel is not None: + return rel, True + return [f, t], False + + +if __name__ == '__main__': + serve() diff --git a/plugin/v64dns/analyze/pb/awdb/__init__.py b/plugin/v64dns/analyze/pb/awdb/__init__.py new file mode 100644 index 0000000..862b6b1 --- /dev/null +++ b/plugin/v64dns/analyze/pb/awdb/__init__.py @@ -0,0 +1,34 @@ +import awdb.reader + +try: + import awdb.extension +except ImportError: + awdb.extension = None + +from awdb.const import (MODE_AUTO, MODE_MMAP, MODE_MMAP_EXT, MODE_FILE, + MODE_MEMORY, MODE_FD) + + +def open_database(database, mode=MODE_AUTO): + has_extension = awdb.extension and hasattr(awdb.extension, + 'Reader') + if (mode == MODE_AUTO and has_extension) or mode == MODE_MMAP_EXT: + if not has_extension: + raise ValueError( + "MODE_MMAP_EXT requires the awdb.extension module to be available" + ) + return awdb.extension.Reader(database) + if mode in (MODE_AUTO, MODE_MMAP, MODE_FILE, MODE_MEMORY, MODE_FD): + return awdb.reader.Reader(database, mode) + raise ValueError('Unsupported open mode: {0}'.format(mode)) + + +def Reader(database): + return open_database(database) + + +__title__ = 'awdb' +__version__ = '1.5.2' +__author__ = '' +__license__ = 'Apache License, Version 2.0' +__copyright__ = 'Copyright 2013-2020 AW, Inc.' 
diff --git a/plugin/v64dns/analyze/pb/awdb/compat.py b/plugin/v64dns/analyze/pb/awdb/compat.py new file mode 100644 index 0000000..9952ac4 --- /dev/null +++ b/plugin/v64dns/analyze/pb/awdb/compat.py @@ -0,0 +1,39 @@ +import sys + +import ipaddress + + +if sys.version_info[0] == 2: + + def compat_ip_address(address): + if isinstance(address, bytes): + address = address.decode() + return ipaddress.ip_address(address) + + int_from_byte = ord + + FileNotFoundError = IOError + + def int_from_bytes(b): + if b: + return int(b.encode("hex"), 16) + return 0 + + byte_from_int = chr + + string_type = basestring + +else: + + def compat_ip_address(address): + return ipaddress.ip_address(address) + + int_from_byte = lambda x: x + + FileNotFoundError = FileNotFoundError + + int_from_bytes = lambda x: int.from_bytes(x, 'big') + + byte_from_int = lambda x: bytes([x]) + + string_type = str diff --git a/plugin/v64dns/analyze/pb/awdb/const.py b/plugin/v64dns/analyze/pb/awdb/const.py new file mode 100644 index 0000000..8618d58 --- /dev/null +++ b/plugin/v64dns/analyze/pb/awdb/const.py @@ -0,0 +1,6 @@ +MODE_AUTO = 0 +MODE_MMAP_EXT = 1 +MODE_MMAP = 2 +MODE_FILE = 4 +MODE_MEMORY = 8 +MODE_FD = 16 diff --git a/plugin/v64dns/analyze/pb/awdb/decoder.py b/plugin/v64dns/analyze/pb/awdb/decoder.py new file mode 100644 index 0000000..1b43fc8 --- /dev/null +++ b/plugin/v64dns/analyze/pb/awdb/decoder.py @@ -0,0 +1,168 @@ +from __future__ import unicode_literals + +import struct + +from awdb.compat import byte_from_int, int_from_byte, int_from_bytes +from awdb.errors import InvalidDatabaseError + + +class Decoder(object): + def __init__(self, database_buffer, pointer_base=0, pointer_test=False): + self._pointer_test = pointer_test + self._buffer = database_buffer + self._pointer_base = pointer_base + + def _decode_array(self, size, offset): + array = [] + for _ in range(size): + (value, offset) = self.decode(offset) + array.append(value) + return array, offset + + def _decode_boolean(self, size, offset): + return size != 0, offset + + def _decode_bytes(self, size, offset): + new_offset = offset + size + return self._buffer[offset:new_offset], new_offset + + def _decode_double(self, size, offset): + self._verify_size(size, 8) + new_offset = offset + size + packed_bytes = self._buffer[offset:new_offset] + (value, ) = struct.unpack(b'!d', packed_bytes) + return value, new_offset + + def _decode_float(self, size, offset): + self._verify_size(size, 4) + new_offset = offset + size + packed_bytes = self._buffer[offset:new_offset] + (value, ) = struct.unpack(b'!f', packed_bytes) + return value, new_offset + + def _decode_int32(self, size, offset): + if size == 0: + return 0, offset + new_offset = offset + size + packed_bytes = self._buffer[offset:new_offset] + + if size != 4: + packed_bytes = packed_bytes.rjust(4, b'\x00') + (value, ) = struct.unpack(b'!i', packed_bytes) + return value, new_offset + + def _decode_map(self, size, offset): + container = {} + for _ in range(size): + (key, offset) = self.decode(offset) + (value, offset) = self.decode(offset) + if key == value: + container[key] = bytes(value, 'utf-8') + else: + container[key] = value + # print("###") + # print(container) + # print("###") + return container, offset + + def _decode_pointer(self, size, offset): + pointer_size = (size >> 3) + 1 + + buf = self._buffer[offset:offset + pointer_size] + new_offset = offset + pointer_size + + if pointer_size == 1: + buf = byte_from_int(size & 0x7) + buf + pointer = struct.unpack(b'!H', buf)[0] + self._pointer_base + elif 
pointer_size == 2: + buf = b'\x00' + byte_from_int(size & 0x7) + buf + pointer = struct.unpack(b'!I', buf)[0] + 2048 + self._pointer_base + elif pointer_size == 3: + buf = byte_from_int(size & 0x7) + buf + pointer = struct.unpack(b'!I', + buf)[0] + 526336 + self._pointer_base + else: + pointer = struct.unpack(b'!I', buf)[0] + self._pointer_base + + if self._pointer_test: + return pointer, new_offset + (value, _) = self.decode(pointer) + return value, new_offset + + def _decode_uint(self, size, offset): + new_offset = offset + size + uint_bytes = self._buffer[offset:new_offset] + return int_from_bytes(uint_bytes), new_offset + + def _decode_utf8_string(self, size, offset): + new_offset = offset + size + return self._buffer[offset:new_offset].decode('utf-8'), new_offset + + _type_decoder = { + 1: _decode_pointer, + 2: _decode_utf8_string, + 3: _decode_double, + 4: _decode_bytes, + 5: _decode_uint, + 6: _decode_uint, + 7: _decode_map, + 8: _decode_int32, + 9: _decode_uint, + 10: _decode_uint, + 11: _decode_array, + 14: _decode_boolean, + 15: _decode_float, + } + + def decode(self, offset): + new_offset = offset + 1 + ctrl_byte = int_from_byte(self._buffer[offset]) + type_num = ctrl_byte >> 5 + if not type_num: + (type_num, new_offset) = self._read_extended(new_offset) + + try: + decoder = self._type_decoder[type_num] + except KeyError: + raise InvalidDatabaseError('Unexpected type number ({type}) ' + 'encountered'.format(type=type_num)) + + (size, new_offset) = self._size_from_ctrl_byte(ctrl_byte, new_offset, + type_num) + return decoder(self, size, new_offset) + + def _read_extended(self, offset): + next_byte = int_from_byte(self._buffer[offset]) + type_num = next_byte + 7 + if type_num < 7: + raise InvalidDatabaseError( + 'Something went horribly wrong in the decoder. 
An ' + 'extended type resolved to a type number < 8 ' + '({type})'.format(type=type_num)) + return type_num, offset + 1 + + def _verify_size(self, expected, actual): + if expected != actual: + raise InvalidDatabaseError( + 'The AW DB file\'s data section contains bad data ' + '(unknown data type or corrupt data)') + + def _size_from_ctrl_byte(self, ctrl_byte, offset, type_num): + size = ctrl_byte & 0x1f + if type_num == 1 or size < 29: + return size, offset + + if size == 29: + size = 29 + int_from_byte(self._buffer[offset]) + return size, offset + 1 + + if size == 30: + new_offset = offset + 2 + size_bytes = self._buffer[offset:new_offset] + size = 285 + struct.unpack(b'!H', size_bytes)[0] + return size, new_offset + + new_offset = offset + 3 + size_bytes = self._buffer[offset:new_offset] + size = struct.unpack(b'!I', b'\x00' + size_bytes)[0] + 65821 + return size, new_offset diff --git a/plugin/v64dns/analyze/pb/awdb/errors.py b/plugin/v64dns/analyze/pb/awdb/errors.py new file mode 100644 index 0000000..5687c9e --- /dev/null +++ b/plugin/v64dns/analyze/pb/awdb/errors.py @@ -0,0 +1,2 @@ +class InvalidDatabaseError(RuntimeError): + pass diff --git a/plugin/v64dns/analyze/pb/awdb/file.py b/plugin/v64dns/analyze/pb/awdb/file.py new file mode 100644 index 0000000..6d654f2 --- /dev/null +++ b/plugin/v64dns/analyze/pb/awdb/file.py @@ -0,0 +1,45 @@ +import os + +try: + from multiprocessing import Lock +except ImportError: + from threading import Lock + + +class FileBuffer(object): + def __init__(self, database): + self._handle = open(database, 'rb') + self._size = os.fstat(self._handle.fileno()).st_size + if not hasattr(os, 'pread'): + self._lock = Lock() + + def __getitem__(self, key): + if isinstance(key, slice): + return self._read(key.stop - key.start, key.start) + if isinstance(key, int): + return self._read(1, key)[0] + raise TypeError("Invalid argument type.") + + def rfind(self, needle, start): + pos = self._read(self._size - start - 1, start).rfind(needle) + if pos == -1: + return pos + return start + pos + + def size(self): + return self._size + + def close(self): + self._handle.close() + + if hasattr(os, 'pread'): + + def _read(self, buffersize, offset): + return os.pread(self._handle.fileno(), buffersize, offset) + + else: + + def _read(self, buffersize, offset): + with self._lock: + self._handle.seek(offset) + return self._handle.read(buffersize) diff --git a/plugin/v64dns/analyze/pb/awdb/reader.py b/plugin/v64dns/analyze/pb/awdb/reader.py new file mode 100644 index 0000000..fc91369 --- /dev/null +++ b/plugin/v64dns/analyze/pb/awdb/reader.py @@ -0,0 +1,208 @@ +from __future__ import unicode_literals + +try: + import mmap +except ImportError: + mmap = None + +import struct + +from awdb.compat import compat_ip_address, string_type +from awdb.const import MODE_AUTO, MODE_MMAP, MODE_FILE, MODE_MEMORY, MODE_FD +from awdb.decoder import Decoder +from awdb.errors import InvalidDatabaseError +from awdb.file import FileBuffer + + +class Reader(object): + + _DATA_SECTION_SEPARATOR_SIZE = 16 + _METADATA_START_MARKER = b"\xAB\xCD\xEFipplus360.com" + + _ipv4_start = None + + def __init__(self, database, mode=MODE_AUTO): + if (mode == MODE_AUTO and mmap) or mode == MODE_MMAP: + with open(database, 'rb') as db_file: + self._buffer = mmap.mmap(db_file.fileno(), + 0, + access=mmap.ACCESS_READ) + self._buffer_size = self._buffer.size() + filename = database + elif mode in (MODE_AUTO, MODE_FILE): + self._buffer = FileBuffer(database) + self._buffer_size = self._buffer.size() + filename = database + 
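+        # MODE_MEMORY reads the whole database into an in-memory bytes object;
+        # MODE_FD expects an already-open file object in place of a path.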
elif mode == MODE_MEMORY: + with open(database, 'rb') as db_file: + self._buffer = db_file.read() + self._buffer_size = len(self._buffer) + filename = database + elif mode == MODE_FD: + self._buffer = database.read() + self._buffer_size = len(self._buffer) + filename = database.name + else: + raise ValueError( + 'Unsupported open mode ({0}). Only MODE_AUTO, MODE_FILE, ' + 'MODE_MEMORY and MODE_FD are supported by the pure Python ' + 'Reader'.format(mode)) + + metadata_start = self._buffer.rfind( + self._METADATA_START_MARKER, max(0, + self._buffer_size - 128 * 1024)) + + if metadata_start == -1: + self.close() + raise InvalidDatabaseError('Error opening database file ({0}). ' + 'Is this a valid AW DB file?' + ''.format(filename)) + + metadata_start += len(self._METADATA_START_MARKER) + metadata_decoder = Decoder(self._buffer, metadata_start) + (metadata, _) = metadata_decoder.decode(metadata_start) + self._metadata = Metadata(**metadata) + + self._decoder = Decoder( + self._buffer, self._metadata.search_tree_size + + self._DATA_SECTION_SEPARATOR_SIZE) + self.closed = False + + def metadata(self): + return self._metadata + + def get(self, ip_address): + (record, _) = self.get_with_prefix_len(ip_address) + return record + + def get_with_prefix_len(self, ip_address): + if isinstance(ip_address, string_type): + address = compat_ip_address(ip_address) + else: + address = ip_address + + try: + packed_address = bytearray(address.packed) + except AttributeError: + raise TypeError('argument 1 must be a string or ipaddress object') + + if address.version == 6 and self._metadata.ip_version == 4: + raise ValueError( + 'Error looking up {0}. You attempted to look up ' + 'an IPv6 address in an IPv4-only database.'.format(ip_address)) + + (pointer, prefix_len) = self._find_address_in_tree(packed_address) + + if pointer: + return self._resolve_data_pointer(pointer), prefix_len + return None, prefix_len + + def _find_address_in_tree(self, packed): + bit_count = len(packed) * 8 + node = self._start_node(bit_count) + node_count = self._metadata.node_count + + i = 0 + while i < bit_count and node < node_count: + bit = 1 & (packed[i >> 3] >> 7 - (i % 8)) + node = self._read_node(node, bit) + i = i + 1 + + if node == node_count: + return 0, i + if node > node_count: + return node, i + + raise InvalidDatabaseError('Invalid node in search tree') + + def _start_node(self, length): + if self._metadata.ip_version != 6 or length == 128: + return 0 + + if self._ipv4_start: + return self._ipv4_start + + node = 0 + for _ in range(96): + if node >= self._metadata.node_count: + break + node = self._read_node(node, 0) + self._ipv4_start = node + return node + + def _read_node(self, node_number, index): + base_offset = node_number * self._metadata.node_byte_size + + record_size = self._metadata.record_size + if record_size == 24: + offset = base_offset + index * 3 + node_bytes = b'\x00' + self._buffer[offset:offset + 3] + elif record_size == 28: + offset = base_offset + 3 * index + node_bytes = bytearray(self._buffer[offset:offset + 4]) + if index: + node_bytes[0] = 0x0F & node_bytes[0] + else: + middle = (0xF0 & node_bytes.pop()) >> 4 + node_bytes.insert(0, middle) + elif record_size == 32: + offset = base_offset + index * 4 + node_bytes = self._buffer[offset:offset + 4] + else: + raise InvalidDatabaseError( + 'Unknown record size: {0}'.format(record_size)) + return struct.unpack(b'!I', node_bytes)[0] + + def _resolve_data_pointer(self, pointer): + resolved = pointer - self._metadata.node_count + \ + 
self._metadata.search_tree_size + + if resolved >= self._buffer_size: + raise InvalidDatabaseError( + "The AW DB file's search tree is corrupt") + + (data, _) = self._decoder.decode(resolved) + return data + + def close(self): + if type(self._buffer) not in (str, bytes): + self._buffer.close() + self.closed = True + + def __exit__(self, *args): + self.close() + + def __enter__(self): + if self.closed: + raise ValueError('Attempt to reopen a closed AW DB') + return self + + +class Metadata(object): + def __init__(self, **kwargs): + self.node_count = kwargs['node_count'] + self.record_size = kwargs['record_size'] + self.ip_version = kwargs['ip_version'] + self.database_type = kwargs['database_type'] + self.languages = kwargs['languages'] + self.binary_format_major_version = kwargs[ + 'binary_format_major_version'] + self.binary_format_minor_version = kwargs[ + 'binary_format_minor_version'] + self.build_epoch = kwargs['build_epoch'] + self.description = kwargs['description'] + + @property + def node_byte_size(self): + return self.record_size // 4 + + @property + def search_tree_size(self): + return self.node_count * self.node_byte_size + + def __repr__(self): + args = ', '.join('%s=%r' % x for x in self.__dict__.items()) + return '{module}.{class_name}({data})'.format( + module=self.__module__, + class_name=self.__class__.__name__, + data=args) diff --git a/plugin/v64dns/analyze/pb/twv6.py b/plugin/v64dns/analyze/pb/twv6.py new file mode 100644 index 0000000..85bd0a8 --- /dev/null +++ b/plugin/v64dns/analyze/pb/twv6.py @@ -0,0 +1,32 @@ +import pandas as pd +from rich.progress import Progress + +from analyzer import node + +data = [] +with Progress() as p: + inputdata = "./53openIPv6.txt" + outpath = "./53openIPv6dnsTW.csv" + # inputdata = "./TW_v6resolver.txt" + # outpath = "./TW_v6resolver.csv" + print(len(open(inputdata).readlines())) + task = p.add_task("[blue]Working...", total=sum(1 for _ in open(inputdata))) + with open(inputdata, "r") as f: + while True: + l = f.readline() + if l == "": + break + ip6 = l.splitlines()[0] + n = node(ip6) + p.update(task, advance=1, refresh=True) + if n.prov == "Taiwan": + data.append({"ip": n.ip, + "owner": n.owner, + "isp": n.isp, + "area": n.prov, + "asn": n.AS, + "lat": n.lat, + "lng": n.lng, + "type": "开放解析器"}) +d = pd.DataFrame(data) +d.to_csv(outpath, index=False) diff --git a/plugin/v64dns/setup.go b/plugin/v64dns/setup.go new file mode 100644 index 0000000..a859532 --- /dev/null +++ b/plugin/v64dns/setup.go @@ -0,0 +1,103 @@ +package v64dns + +import ( + "ohmydns2/core/dnsserver" + "ohmydns2/plugin" + olog "ohmydns2/plugin/pkg/log" + "ohmydns2/plugin/pkg/prober" + "ohmydns2/plugin/v64dns/analyze" + "strconv" + + "github.com/coredns/caddy" +) + +func init() { plugin.Register("v64dns", setup) } + +// 读取参数 +func setup(c *caddy.Controller) error { + v64dnsserver, err := parseArg(c) + if err != nil { + log.Error("参数存在问题") + return err + } + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + return v64dnsserver + }) + return nil +} + +// 解析参数 +func parseArg(c *caddy.Controller) (*V64dns, error) { + v := new(V64dns) + v.p.maxLen = chain_maxlen + v.p.dmChange = false + + for c.Next() { + arg := c.RemainingArgs() + v.zone = plugin.OriginsFromArgsOrServerBlock(arg, c.ServerBlockKeys)[0] + + for c.NextBlock() { + switch c.Val() { + case "v4ns": + args := c.RemainingArgs() + if len(args) != 2 { + // 缺少参数 + log.Error("v4ns缺少参数,请检查") + } + v.ipv4NS = args[0] + "." 
+ v.zone + v.V4NSAddr = args[1] + case "v6ns": + args := c.RemainingArgs() + if len(args) != 2 { + // 缺少参数 + log.Error("v6ns缺少参数,请检查") + } + v.ipv6NS = args[0] + "." + v.zone + v.V6NSAddr = args[1] + case "chain": + args := c.RemainingArgs() + if len(args) != 2 { + // 缺少参数 + log.Error("域名链缺少参数,请检查") + } + v.p.v4domain = args[0] + "." + v.zone + v.p.v6domain = args[1] + "." + v.zone + v.p.v6subdomain = args[1] + v.p.v4subdomain = args[0] + v.p.dmChange = true + case "ip-embed": + v.p.ipEmbed = true + case "chain-mlen": + if !c.Next() { + return nil, c.ArgErr() + } + v.p.maxLen, _ = strconv.Atoi(c.Val()) + case "analyze": + args := c.RemainingArgs() + if len(args) != 4 { + // 缺少参数 + log.Error("analyze参数错误,请检查") + } + //创建Analyzer对象 + a, err := analyze.NewAnalyzer(args[0], args[1], args[2], args[3]) + if err != nil { + return nil, err + } + v.a = a + } + } + break + } + // 监控停止信号 + //go EL.Stop() + log.Infof("v64权威服务器启动, 工作参数为 \n Zone:%v, NS4:%v, NS6:%v, IPv4子域:%v, IPv6子域:%v", v.zone, v.ipv4NS, v.ipv6NS, v.p.v4subdomain, v.p.v6subdomain) + log.Infof("分析器启动, 工作参数为 \n url:%v", v.a.Graphuri) + log.Infof("测试样例: " + prober.MakeTestProbev64(v.p.v4subdomain, v.zone)) + return v, nil +} + +var log = olog.NewWithPlugin("v64dns") + +const ( + chain_maxlen = 4 +) diff --git a/plugin/v64dns/v64dns.go b/plugin/v64dns/v64dns.go new file mode 100644 index 0000000..2433ddc --- /dev/null +++ b/plugin/v64dns/v64dns.go @@ -0,0 +1,72 @@ +package v64dns + +import ( + "context" + "github.com/miekg/dns" + olog "ohmydns2/plugin/pkg/log" + "ohmydns2/plugin/pkg/request" + "ohmydns2/plugin/v64dns/analyze" + "strings" +) + +// 针对v64dns请求的抽象 +//type v64Request struct { +// eid string // 请求归属的实验ID +//} + +// V64dns代表了水印权威 +type V64dns struct { + zone string + ipv4NS string + ipv6NS string + V4NSAddr string + V6NSAddr string + p Policy // 生成响应的策略 + a *analyze.Analyzer // 分析器配置 +} + +// ServeDNS +// +// @Description: v64dns主要处理程序 +// @receiver d +// @param ctx +// @param w +// @param r +// @return int +// @return error +func (v V64dns) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + state := request.Request{W: w, Req: r} + + msg := new(dns.Msg) + msg.SetReply(r) + msg.Authoritative = true + // 请求域名 + qname := strings.ToLower(state.QName()) + switch v.VaildRequest(qname) { + case 0: + // 正常请求 + log.Info("Receive:" + state.QName() + "====>" + state.Type()) + msg = v.ResponseHandler(msg, state) + case 2: + log.Info("按照mini请求处理") + // Qname mini + msg = v.ResponseNSorAdd(msg, state, 1) + case 1: + //其他请求不响应 + return 0, nil + } + + err := w.WriteMsg(msg) + if err != nil { + olog.Info(err.Error()) + return dns.RcodeServerFailure, err + } + return 0, nil +} + +// 实现Handler接口 +func (v V64dns) Name() string { + return name +} + +const name = "v64dns" diff --git a/plugin/v64dns/v64dns_policy.go b/plugin/v64dns/v64dns_policy.go new file mode 100644 index 0000000..1e70a03 --- /dev/null +++ b/plugin/v64dns/v64dns_policy.go @@ -0,0 +1,169 @@ +package v64dns + +import ( + "github.com/miekg/dns" + "net" + "ohmydns2/plugin/pkg/request" + "strconv" + "strings" +) + +// Policy 定义了权威生成记录可用的方法 +type Policy struct { + v4domain string + v6domain string + v4subdomain string + v6subdomain string + ipEmbed bool + maxLen int //解析链最大长度 + dmChange bool //是否切换域名 +} + +// ResponseHandler 跨栈解析响应主要处理函数 +func (v V64dns) ResponseHandler(msg *dns.Msg, state request.Request) *dns.Msg { + step, _ := strconv.Atoi(string(rune(state.QName()[1:]))) + + // 未到达最后一步 + if step < v.p.maxLen { + return v.ResponseCNAME(msg, state) + } + // 
到达最后一步 + switch state.QType() { + // 只处理TXT,CNAME,AAAA记录 + case dns.TypeTXT, dns.TypeCNAME, dns.TypeAAAA: + return v.ResponseTXT(msg, state) + case dns.TypeNS: + return v.ResponseNSorAdd(msg, state, 0) + default: + log.Info("no") + // 不符合要求的请求类型直接返回空响应 + return msg + } +} + +// ResponseTXT 响应TXT,认为到达最后一步 +func (v V64dns) ResponseTXT(msg *dns.Msg, state request.Request) *dns.Msg { + qname := strings.ToLower(state.QName()) + + // 记录关联关系 + dSlice := strings.Split(strings.ToLower(qname), ".") + if len(dSlice) > 4 { + // 将最近一次编码拆分开 + cList := strings.Split(dSlice[len(dSlice)-6], "-") + iaddr :="" + // 首部编码如果长度为4则为IPv6地址编码,否则为IPv4的 + if len(cList[0]) == 4 { + iaddr = strings.ReplaceAll(dSlice[len(dSlice)-6], "-", ":") + msg = v.SetAuthAdd(msg, state, 4) + } else { + iaddr = strings.ReplaceAll(dSlice[len(dSlice)-6], "-", ".") + msg = v.SetAuthAdd(msg, state, 6) + } + iaddr = net.ParseIP(iaddr).String() + oaddr := state.IP() + // 调用Python + go v.a.Go2py([]string{iaddr, oaddr, "1"}) + + } + + answer := new(dns.TXT) + answer.Txt = []string{state.QName()} + answer.Hdr = dns.RR_Header{Name: state.QName(), Ttl: 3600, Class: dns.ClassINET, Rrtype: dns.TypeTXT} + msg.Answer = append(msg.Answer, answer) + return msg +} + +// ResponseNSorAdd 0返回NS记录应答,1返回胶水记录 +func (v V64dns) ResponseNSorAdd(msg *dns.Msg, state request.Request, flag int) *dns.Msg { + qname := strings.ToLower(state.QName()) + //sub := "" + //for _, qs := range strings.Split(qname, ".") { + // switch v.MatchType(qs) { + // case 2: + // sub = qs + // } + //} + + answer := new(dns.NS) + answer.Hdr.Ttl = 3600 + answer.Hdr.Class = dns.ClassINET + answer.Hdr.Rrtype = dns.TypeNS + answer.Hdr.Name = state.QName() + + if strings.ContainsAny(qname, v.p.v6domain) { + answer.Ns = v.ipv6NS + msg = v.SetAuthAdd(msg, state, 6) + } else { + answer.Ns = v.ipv4NS + msg = v.SetAuthAdd(msg, state, 4) + } + + if flag == 0 { + msg.Answer = append(msg.Answer, answer) + } + + return msg +} + +// ResponseCNAME 响应CNAME记录,生成方式为: +// [进度标识].[随机数].[水印].[子域名].[实验域名] +// ====>[进度标识].[随机数].[水印].<新水印>.<新子域名>.[实验域名] +func (v V64dns) ResponseCNAME(msg *dns.Msg, state request.Request) *dns.Msg { + qname := state.QName() + + dSlice := strings.Split(strings.ToLower(qname), ".") + if len(dSlice) > 4 { + // 将最近一次编码拆分开 + cList := strings.Split(dSlice[len(dSlice)-6], "-") + iaddr := "" + // 首部编码如果长度为4则为IPv6地址编码,否则为IPv4的 + if len(cList[0]) == 4 { + iaddr = strings.ReplaceAll(dSlice[len(dSlice)-6], "-", ":") + msg = v.SetAuthAdd(msg, state, 4) + } else { + iaddr = strings.ReplaceAll(dSlice[len(dSlice)-6], "-", ".") + msg = v.SetAuthAdd(msg, state, 6) + } + iaddr = net.ParseIP(iaddr).String() + oaddr := state.IP() + // 调用Python脚本 + step, _ := strconv.Atoi(string(rune(state.QName()[1]))) + if step == 1 { + go v.a.Go2py([]string{iaddr, oaddr, "0"}) + } else { + go v.a.Go2py([]string{iaddr, oaddr, "1"}) + } + + } + + answer := new(dns.CNAME) + answer.Hdr = dns.RR_Header{ + Name: qname, + Rrtype: dns.TypeCNAME, + Class: dns.ClassINET, + Ttl: 1, + } + + // 响应内容 + // 实验进度+1 + answer.Target += nextProgress(dSlice[0]) + "." + // 内容填充 + for _, i := range dSlice[1 : len(dSlice)-5] { + answer.Target += i + "." + } + answer.Target += ip2id(state.IP()) + "." + answer.Target += dSlice[len(dSlice)-5] + "." + answer.Target += v.changeSubDomain(dSlice[len(dSlice)-4]) + "." + for _, i := range dSlice[len(dSlice)-3 : len(dSlice)-1] { + answer.Target += i + "." 
+ } + msg.Answer = append(msg.Answer, answer) + log.Debug(answer) + return msg +} + +const ( + v6 = 0 + v4 = 1 + //interval = 20 +) diff --git a/plugin/v64dns/v64dnsutil.go b/plugin/v64dns/v64dnsutil.go new file mode 100644 index 0000000..f0a1db8 --- /dev/null +++ b/plugin/v64dns/v64dnsutil.go @@ -0,0 +1,154 @@ +package v64dns + +import ( + "net" + "ohmydns2/plugin/pkg/request" + "strconv" + "strings" + + "github.com/miekg/dns" +) + +// 判断接收到的域名是否实验可用,0为正常,1代表非实验需要,2代表可能是QName最小化的请求 +func (v V64dns) VaildRequest(d string) int { + //获取到每一级域名的字符串 + ds := strings.Split(d, ".") + //判断是否为目标域名 + if strings.Contains(d, v.zone) { + //判断是否有解析进度,含有c且长度小于4即为进度标识 + if strings.Contains(ds[0], "c") && len(ds[0]) < 4 { + return 0 + } + //不存在解析进度,则可能为Qname最小化的请求,返回2 + return 2 + } + return 1 +} + +// AuthInfo +// +// @Description: 添加权威服务器的授权信息 +// @receiver v +// @param d dns响应 +// @param iptype IP类型,0代表v6,1代表v4 +// @return *dns.Msg +// @return error +func (v *V64dns) AuthInfo(d *dns.Msg, iptype int) (*dns.Msg, error) { + rr := new(dns.NS) + rr.Hdr = dns.RR_Header{Class: dns.ClassINET, Ttl: 3600, Rrtype: dns.TypeNS} + + if iptype == 0 { + rr.Hdr.Name = v.p.v6domain + rr.Ns = v.ipv6NS + } else { + rr.Hdr.Name = v.p.v4domain + rr.Ns = v.ipv4NS + } + d.Ns = append(d.Ns, rr) + return d, nil +} + +// AdditionalInfo +// +// @Description: 添加胶水记录 +// @receiver v +// @param d dns响应 +// @param iptype IP类型,0代表v6,1代表v4 +// @return *dns.Msg +// @return error +func (v *V64dns) AdditionalInfo(d *dns.Msg, state request.Request, iptype int) (*dns.Msg, error) { + // 根据不同的NS返回额外信息 + if iptype == 0 { + dnsadd := new(dns.AAAA) + dnsadd.Hdr = dns.RR_Header{Name: v.ipv6NS, Ttl: 3600, Class: dns.ClassINET, Rrtype: dns.TypeAAAA} + a, _, _ := net.ParseCIDR(v.V6NSAddr + "/64") + //a := net.ParseIP(state.LocalIP()) + dnsadd.AAAA = a.To16() + d.Extra = append(d.Extra, dnsadd) + } else { + dnsadd := new(dns.A) + dnsadd.Hdr = dns.RR_Header{Name: v.ipv4NS, Ttl: 3600, Class: dns.ClassINET, Rrtype: dns.TypeA} + a, _, _ := net.ParseCIDR(v.V4NSAddr + "/24") + //a := net.ParseIP(state.LocalIP()) + dnsadd.A = a.To4() + d.Extra = append(d.Extra, dnsadd) + } + return d, nil +} + +// 根据各级域名特征判断其类型,0为zone,1为水印部分,2为v4/v6子域名,3为随机数,4为进度标识,5为实验对象 +func (v V64dns) MatchType(s string) int { + if strings.Contains(s, "rip") { + return 5 + } + if s == v.p.v4subdomain || s == v.p.v6subdomain { + return 2 + } + if strings.Contains(s, "-") && len(s) > 5 { + return 1 + } + if len(s) == 2 && string(s[0]) == "c" { + return 4 + } + for _, i := range strings.Split(v.zone, ".") { + if s == i { + return 0 + } + } + return 3 +} + +// 切换域名 +func (v V64dns) changeSubDomain(s string) string { + // 切换标志位为false,直接返回原域名 + if !v.p.dmChange { + return s + } + switch s { + case v.p.v4subdomain: + return v.p.v6subdomain + case v.p.v6subdomain: + return v.p.v4subdomain + default: + log.Error("子域名错误") + return "v64" + } +} + +// 修改进度标识 +func nextProgress(s string) string { + n, _ := strconv.Atoi(string(s[1])) + n = n + 1 + return string(s[0]) + strconv.Itoa(n) +} + +// 封装函数,同时添加权威信息和胶水记录,根据子域名切换,domainType=6 or 4 +func (v V64dns) SetAuthAdd(res *dns.Msg, state request.Request, domaintype int) *dns.Msg { + if domaintype == 6 { + res, _ = v.AdditionalInfo(res, state, v6) + res, _ = v.AuthInfo(res, v6) + } else { + res, _ = v.AdditionalInfo(res, state, v4) + res, _ = v.AuthInfo(res, v4) + } + return res +} + +// 从域名中提取eid +func (v V64dns) ExtractEid(qname string) string { + for _, qs := range strings.Split(qname, ".") { + if v.MatchType(qs) == 5 { + return qs + } + } + return "" 
+}
+
+// ip2id encodes an IP address as a single DNS label by replacing its separators ('.' or ':') with '-'.
+func ip2id(ip string) string {
+	if strings.Contains(ip, ".") {
+		return strings.ReplaceAll(ip, ".", "-")
+	}
+	return strings.ReplaceAll(ip, ":", "-")
+}
diff --git a/plugin/view/README.md b/plugin/view/README.md
new file mode 100644
index 0000000..8522727
--- /dev/null
+++ b/plugin/view/README.md
@@ -0,0 +1,135 @@
+# view
+
+## Name
+
+*view* - defines conditions that must be met for a DNS request to be routed to the server block.
+
+## Description
+
+*view* defines an expression that must evaluate to true for a DNS request to be routed to the server block.
+This enables advanced server block routing functions such as split DNS.
+
+## Syntax
+
+```
+view NAME {
+  expr EXPRESSION
+}
+```
+
+* `view` **NAME** - The name of the view used by metrics and exported as metadata for requests that match the
+  view's expression
+* `expr` **EXPRESSION** - CoreDNS will only route incoming queries to the enclosing server block
+  if the **EXPRESSION** evaluates to true. See the **Expressions** section for available variables and functions.
+  If multiple instances of *view* are defined, every **EXPRESSION** must evaluate to true for CoreDNS to route
+  incoming queries to the enclosing server block.
+
+For expression syntax and examples, see the Expressions and Examples sections.
+
+## Examples
+
+Implement CIDR-based split-DNS routing. This will return a different
+answer for `test.` depending on the client's IP address. It returns ...
+* `test. 3600 IN A 1.1.1.1`, for queries with a source address in 127.0.0.0/24
+* `test. 3600 IN A 2.2.2.2`, for queries with a source address in 192.168.0.0/16
+* `test. 3600 IN A 3.3.3.3`, for all others
+
+```
+. {
+  view example1 {
+    expr incidr(client_ip(), '127.0.0.0/24')
+  }
+  hosts {
+    1.1.1.1 test
+  }
+}
+
+. {
+  view example2 {
+    expr incidr(client_ip(), '192.168.0.0/16')
+  }
+  hosts {
+    2.2.2.2 test
+  }
+}
+
+. {
+  hosts {
+    3.3.3.3 test
+  }
+}
+```
+
+Send all `A` and `AAAA` requests to `10.0.0.6`, and all other requests to `10.0.0.1`.
+
+```
+. {
+  view example {
+    expr type() in ['A', 'AAAA']
+  }
+  forward . 10.0.0.6
+}
+
+. {
+  forward . 10.0.0.1
+}
+```
+
+Send all requests for `abc.*.example.com` (where * can be any number of labels) to `10.0.0.2`, and all other
+requests to `10.0.0.1`.
+Note that the regex pattern is enclosed in single quotes, and backslashes are escaped with backslashes.
+
+```
+. {
+  view example {
+    expr name() matches '^abc\\..*\\.example\\.com\\.$'
+  }
+  forward . 10.0.0.2
+}
+
+. {
+  forward . 10.0.0.1
+}
+```
+
+## Expressions
+
+To evaluate expressions, *view* uses the antonmedv/expr package (https://github.com/antonmedv/expr).
+For example, an expression could look like:
+`(type() == 'A' && name() == 'example.com') || client_ip() == '1.2.3.4'`.
+
+All expressions should be written to evaluate to a boolean value.
+
+See https://github.com/antonmedv/expr/blob/master/docs/Language-Definition.md as a detailed reference for valid syntax.
+
+### Available Expression Functions
+
+In the context of the *view* plugin, expressions can reference DNS query information by using the utility
+functions defined below.
+
+#### DNS Query Functions
+
+* `bufsize() int`: the EDNS0 buffer size advertised in the query
+* `class() string`: class of the request (IN, CH, ...)
+* `client_ip() string`: client's IP address, for IPv6 addresses these are enclosed in brackets: `[::1]` +* `do() bool`: the EDNS0 DO (DNSSEC OK) bit set in the query +* `id() int`: query ID +* `name() string`: name of the request (the domain name requested) +* `opcode() int`: query OPCODE +* `port() string`: client's port +* `proto() string`: protocol used (tcp or udp) +* `server_ip() string`: server's IP address; for IPv6 addresses these are enclosed in brackets: `[::1]` +* `server_port() string` : server's port +* `size() int`: request size in bytes +* `type() string`: type of the request (A, AAAA, TXT, ...) + +#### Utility Functions + +* `incidr(ip string, cidr string) bool`: returns true if _ip_ is within _cidr_ +* `metadata(label string)` - returns the value for the metadata matching _label_ + +## Metadata + +The view plugin will publish the following metadata, if the *metadata* +plugin is also enabled: + +* `view/name`: the name of the view handling the current request diff --git a/plugin/view/metadata.go b/plugin/view/metadata.go new file mode 100644 index 0000000..6ee9bc0 --- /dev/null +++ b/plugin/view/metadata.go @@ -0,0 +1,16 @@ +package view + +import ( + "context" + + "github.com/coredns/coredns/plugin/metadata" + "github.com/coredns/coredns/request" +) + +// Metadata implements the metadata.Provider interface. +func (v *View) Metadata(ctx context.Context, state request.Request) context.Context { + metadata.SetValueFunc(ctx, "view/name", func() string { + return v.viewName + }) + return ctx +} diff --git a/plugin/view/setup.go b/plugin/view/setup.go new file mode 100644 index 0000000..77ef791 --- /dev/null +++ b/plugin/view/setup.go @@ -0,0 +1,65 @@ +package view + +import ( + "context" + "strings" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/expression" + + "github.com/antonmedv/expr" +) + +func init() { plugin.Register("view", setup) } + +func setup(c *caddy.Controller) error { + cond, err := parse(c) + if err != nil { + return plugin.Error("view", err) + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + cond.Next = next + return cond + }) + + return nil +} + +func parse(c *caddy.Controller) (*View, error) { + v := new(View) + + i := 0 + for c.Next() { + i++ + if i > 1 { + return nil, plugin.ErrOnce + } + args := c.RemainingArgs() + if len(args) != 1 { + return nil, c.ArgErr() + } + v.viewName = args[0] + + for c.NextBlock() { + switch c.Val() { + case "expr": + args := c.RemainingArgs() + prog, err := expr.Compile(strings.Join(args, " "), expr.Env(expression.DefaultEnv(context.Background(), nil)), expr.DisableBuiltin("type")) + if err != nil { + return v, err + } + v.progs = append(v.progs, prog) + if err != nil { + return nil, err + } + continue + default: + return nil, c.Errf("unknown property '%s'", c.Val()) + } + } + } + return v, nil +} diff --git a/plugin/view/setup_test.go b/plugin/view/setup_test.go new file mode 100644 index 0000000..7c78380 --- /dev/null +++ b/plugin/view/setup_test.go @@ -0,0 +1,38 @@ +package view + +import ( + "testing" + + "github.com/coredns/caddy" +) + +func TestSetup(t *testing.T) { + tests := []struct { + input string + shouldErr bool + progCount int + }{ + {"view example {\n expr name() == 'example.com.'\n}", false, 1}, + {"view example {\n expr incidr(client_ip(), '10.0.0.0/24')\n}", false, 1}, + {"view example {\n expr name() == 'example.com.'\n expr name() == 'example2.com.'\n}", false, 2}, + 
{"view", true, 0}, + {"view example {\n expr invalid expression\n}", true, 0}, + } + + for i, test := range tests { + v, err := parse(caddy.NewTestController("dns", test.input)) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected error but found none for input %s", i, test.input) + } + if err != nil && !test.shouldErr { + t.Errorf("Test %d: Expected no error but found one for input %s. Error was: %v", i, test.input, err) + } + if test.shouldErr { + continue + } + if test.progCount != len(v.progs) { + t.Errorf("Test %d: Expected prog length %d, but got %d for %s.", i, test.progCount, len(v.progs), test.input) + } + } +} diff --git a/plugin/view/view.go b/plugin/view/view.go new file mode 100644 index 0000000..448a63a --- /dev/null +++ b/plugin/view/view.go @@ -0,0 +1,48 @@ +package view + +import ( + "context" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/pkg/expression" + "github.com/coredns/coredns/request" + + "github.com/antonmedv/expr" + "github.com/antonmedv/expr/vm" + "github.com/miekg/dns" +) + +// View is a plugin that enables configuring expression based advanced routing +type View struct { + progs []*vm.Program + viewName string + Next plugin.Handler +} + +// Filter implements dnsserver.Viewer. It returns true if all View rules evaluate to true for the given state. +func (v *View) Filter(ctx context.Context, state *request.Request) bool { + env := expression.DefaultEnv(ctx, state) + for _, prog := range v.progs { + result, err := expr.Run(prog, env) + if err != nil { + return false + } + if b, ok := result.(bool); ok && b { + continue + } + // anything other than a boolean true result is considered false + return false + } + return true +} + +// ViewName implements dnsserver.Viewer. It returns the view name +func (v *View) ViewName() string { return v.viewName } + +// Name implements the Handler interface +func (*View) Name() string { return "view" } + +// ServeDNS implements the Handler interface. +func (v *View) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + return plugin.NextOrFailure(v.Name(), v.Next, ctx, w, r) +} diff --git a/plugin/whoami/README.md b/plugin/whoami/README.md new file mode 100644 index 0000000..55d0388 --- /dev/null +++ b/plugin/whoami/README.md @@ -0,0 +1,58 @@ +# whoami + +## Name + +*whoami* - returns your resolver's local IP address, port and transport. + +## Description + +The *whoami* plugin is not really that useful, but can be used for having a simple (fast) endpoint +to test clients against. When *whoami* returns a response it will have your client's IP address in +the additional section as either an A or AAAA record. + +The reply always has an empty answer section. The port and transport are included in the additional +section as a SRV record, transport can be "tcp" or "udp". + +~~~ txt +._<transport>.qname. 0 IN SRV 0 0 <port> . +~~~ + +The *whoami* plugin will respond to every A or AAAA query, regardless of the query name. + +If CoreDNS can't find a Corefile on startup this is the _default_ plugin that gets loaded. As such +it can be used to check that CoreDNS is responding to queries. Other than that this plugin is of +limited use in production. + +## Syntax + +~~~ txt +whoami +~~~ + +## Examples + +Start a server on the default port and load the *whoami* plugin. + +~~~ corefile +example.org { + whoami +} +~~~ + +When queried for "example.org A", CoreDNS will respond with: + +~~~ txt +;; QUESTION SECTION: +;example.org. IN A + +;; ADDITIONAL SECTION: +example.org. 
0 IN A 10.240.0.1 +_udp.example.org. 0 IN SRV 0 0 40212 +~~~ + +## See Also + +[Read the blog post][blog] on how this plugin is built, or [explore the source code][code]. + +[blog]: https://coredns.io/2017/03/01/how-to-add-plugins-to-coredns/ +[code]: https://github.com/coredns/coredns/blob/master/plugin/whoami/ diff --git a/plugin/whoami/fuzz.go b/plugin/whoami/fuzz.go new file mode 100644 index 0000000..0525398 --- /dev/null +++ b/plugin/whoami/fuzz.go @@ -0,0 +1,13 @@ +//go:build gofuzz + +package whoami + +import ( + "github.com/coredns/coredns/plugin/pkg/fuzz" +) + +// Fuzz fuzzes cache. +func Fuzz(data []byte) int { + w := Whoami{} + return fuzz.Do(w, data) +} diff --git a/plugin/whoami/log_test.go b/plugin/whoami/log_test.go new file mode 100644 index 0000000..460c11c --- /dev/null +++ b/plugin/whoami/log_test.go @@ -0,0 +1,5 @@ +package whoami + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/plugin/whoami/setup.go b/plugin/whoami/setup.go new file mode 100644 index 0000000..1602740 --- /dev/null +++ b/plugin/whoami/setup.go @@ -0,0 +1,22 @@ +package whoami + +import ( + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" +) + +func init() { plugin.Register("whoami", setup) } + +func setup(c *caddy.Controller) error { + c.Next() // 'whoami' + if c.NextArg() { + return plugin.Error("whoami", c.ArgErr()) + } + + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + return Whoami{} + }) + + return nil +} diff --git a/plugin/whoami/setup_test.go b/plugin/whoami/setup_test.go new file mode 100644 index 0000000..18a5b94 --- /dev/null +++ b/plugin/whoami/setup_test.go @@ -0,0 +1,19 @@ +package whoami + +import ( + "testing" + + "github.com/coredns/caddy" +) + +func TestSetup(t *testing.T) { + c := caddy.NewTestController("dns", `whoami`) + if err := setup(c); err != nil { + t.Fatalf("Expected no errors, but got: %v", err) + } + + c = caddy.NewTestController("dns", `whoami example.org`) + if err := setup(c); err == nil { + t.Fatalf("Expected errors, but got: %v", err) + } +} diff --git a/plugin/whoami/whoami.go b/plugin/whoami/whoami.go new file mode 100644 index 0000000..b46736c --- /dev/null +++ b/plugin/whoami/whoami.go @@ -0,0 +1,60 @@ +// Package whoami implements a plugin that returns details about the resolving +// querying it. +package whoami + +import ( + "context" + "net" + "strconv" + + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +const name = "whoami" + +// Whoami is a plugin that returns your IP address, port and the protocol used for connecting +// to CoreDNS. +type Whoami struct{} + +// ServeDNS implements the plugin.Handler interface. +func (wh Whoami) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + state := request.Request{W: w, Req: r} + + a := new(dns.Msg) + a.SetReply(r) + a.Authoritative = true + + ip := state.IP() + var rr dns.RR + + switch state.Family() { + case 1: + rr = new(dns.A) + rr.(*dns.A).Hdr = dns.RR_Header{Name: state.QName(), Rrtype: dns.TypeA, Class: state.QClass()} + rr.(*dns.A).A = net.ParseIP(ip).To4() + case 2: + rr = new(dns.AAAA) + rr.(*dns.AAAA).Hdr = dns.RR_Header{Name: state.QName(), Rrtype: dns.TypeAAAA, Class: state.QClass()} + rr.(*dns.AAAA).AAAA = net.ParseIP(ip) + } + + srv := new(dns.SRV) + srv.Hdr = dns.RR_Header{Name: "_" + state.Proto() + "." + state.QName(), Rrtype: dns.TypeSRV, Class: state.QClass()} + if state.QName() == "." 
{ + srv.Hdr.Name = "_" + state.Proto() + state.QName() + } + port, _ := strconv.ParseUint(state.Port(), 10, 16) + srv.Port = uint16(port) + srv.Target = "." + + a.Extra = []dns.RR{rr, srv} + + w.WriteMsg(a) + + return 0, nil +} + +// Name implements the Handler interface. +func (wh Whoami) Name() string { return name } diff --git a/plugin/whoami/whoami_test.go b/plugin/whoami/whoami_test.go new file mode 100644 index 0000000..fa6a6f0 --- /dev/null +++ b/plugin/whoami/whoami_test.go @@ -0,0 +1,81 @@ +package whoami + +import ( + "context" + "testing" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestWhoami(t *testing.T) { + wh := Whoami{} + if wh.Name() != name { + t.Errorf("expected plugin name: %s, got %s", wh.Name(), name) + } + tests := []struct { + qname string + qtype uint16 + remote string + expectedCode int + expectedReply []string // ownernames for the records in the additional section. + expectedErr error + }{ + { + qname: "example.org", + qtype: dns.TypeA, + expectedCode: dns.RcodeSuccess, + expectedReply: []string{"example.org.", "_udp.example.org."}, + expectedErr: nil, + }, + // Case insensitive and case preserving + { + qname: "Example.ORG", + qtype: dns.TypeA, + expectedCode: dns.RcodeSuccess, + expectedReply: []string{"Example.ORG.", "_udp.Example.ORG."}, + expectedErr: nil, + }, + { + qname: "example.org", + qtype: dns.TypeA, + remote: "2003::1/64", + expectedCode: dns.RcodeSuccess, + expectedReply: []string{"example.org.", "_udp.example.org."}, + expectedErr: nil, + }, + { + qname: "Example.ORG", + qtype: dns.TypeA, + remote: "2003::1/64", + expectedCode: dns.RcodeSuccess, + expectedReply: []string{"Example.ORG.", "_udp.Example.ORG."}, + expectedErr: nil, + }, + } + + ctx := context.TODO() + + for i, tc := range tests { + req := new(dns.Msg) + req.SetQuestion(dns.Fqdn(tc.qname), tc.qtype) + rec := dnstest.NewRecorder(&test.ResponseWriter{RemoteIP: tc.remote}) + code, err := wh.ServeDNS(ctx, rec, req) + if err != tc.expectedErr { + t.Errorf("Test %d: Expected error %v, but got %v", i, tc.expectedErr, err) + } + if code != tc.expectedCode { + t.Errorf("Test %d: Expected status code %d, but got %d", i, tc.expectedCode, code) + } + if len(tc.expectedReply) != 0 { + for i, expected := range tc.expectedReply { + actual := rec.Msg.Extra[i].Header().Name + if actual != expected { + t.Errorf("Test %d: Expected answer %s, but got %s", i, expected, actual) + } + } + } + } +} |
