move hardcodedRecords into Xip struct for instance isolation

- replace global hardcodedRecords/mutex with instance fields
- add initialRecords() factory for fresh record copies per instance
- rename initHardcodedRecords to initNameServers, pass nameservers explicitly
- add TestInstanceIsolation to verify instances don't share state
- fix unassigned err in certs/persistFiles
This commit is contained in:
m5r
2026-01-18 10:51:07 +01:00
parent aac894ae6f
commit abb97cce56
4 changed files with 131 additions and 63 deletions

View File

@@ -6,6 +6,7 @@ import (
"os"
"regexp"
"strings"
"sync"
"time"
"github.com/miekg/dns"
@@ -18,6 +19,8 @@ type Xip struct {
domain string
email string
dnsPort uint
recordsMu sync.RWMutex
records map[string]hardcodedRecord
}
type Option func(*Xip)
@@ -42,13 +45,15 @@ func WithDnsPort(port uint) Option {
func WithNameServers(nameServers []string) Option {
return func(x *Xip) {
x.recordsMu.Lock()
defer x.recordsMu.Unlock()
for i, ns := range nameServers {
name := fmt.Sprintf("ns%d.%s.", i+1, x.domain)
ip := net.ParseIP(ns)
entry := hardcodedRecords[name]
entry := x.records[name]
entry.A = append(entry.A, ip)
hardcodedRecords[name] = entry
x.records[name] = entry
x.nameServers = append(x.nameServers, name)
}
@@ -69,9 +74,11 @@ func (xip *Xip) SetTXTRecord(fqdn string, value string) {
return
}
if rootRecords, ok := hardcodedRecords[fqdn]; ok {
xip.recordsMu.Lock()
defer xip.recordsMu.Unlock()
if rootRecords, ok := xip.records[fqdn]; ok {
rootRecords.TXT = []string{value}
hardcodedRecords[fmt.Sprintf("_acme-challenge.%s.", xip.domain)] = rootRecords
xip.records[fmt.Sprintf("_acme-challenge.%s.", xip.domain)] = rootRecords
}
}
@@ -82,18 +89,23 @@ func (xip *Xip) UnsetTXTRecord(fqdn string) {
return
}
if rootRecords, ok := hardcodedRecords[fqdn]; ok {
xip.recordsMu.Lock()
defer xip.recordsMu.Unlock()
if rootRecords, ok := xip.records[fqdn]; ok {
rootRecords.TXT = []string{}
hardcodedRecords[fmt.Sprintf("_acme-challenge.%s.", xip.domain)] = rootRecords
xip.records[fmt.Sprintf("_acme-challenge.%s.", xip.domain)] = rootRecords
}
}
func (xip *Xip) fqdnToA(fqdn string) []*dns.A {
normalizedFqdn := strings.ToLower(fqdn)
if hardcodedRecords[normalizedFqdn].A != nil {
xip.recordsMu.RLock()
records := xip.records[normalizedFqdn].A
xip.recordsMu.RUnlock()
if records != nil {
var aRecords []*dns.A
for _, record := range hardcodedRecords[normalizedFqdn].A {
for _, record := range records {
aRecords = append(aRecords, &dns.A{
Hdr: dns.RR_Header{
Ttl: uint32((time.Minute * 5).Seconds()),
@@ -154,12 +166,15 @@ func (xip *Xip) handleA(question dns.Question, message *dns.Msg) {
func (xip *Xip) handleAAAA(question dns.Question, message *dns.Msg) {
fqdn := question.Name
normalizedFqdn := strings.ToLower(fqdn)
if hardcodedRecords[normalizedFqdn].AAAA == nil {
xip.recordsMu.RLock()
records := xip.records[normalizedFqdn].AAAA
xip.recordsMu.RUnlock()
if records == nil {
xip.answerWithAuthority(question, message)
return
}
for _, record := range hardcodedRecords[normalizedFqdn].AAAA {
for _, record := range records {
message.Answer = append(message.Answer, &dns.AAAA{
Hdr: dns.RR_Header{
Ttl: uint32((time.Minute * 5).Seconds()),
@@ -209,12 +224,15 @@ func chunkBy(str string, chunkSize int) (chunks []string) {
func (xip *Xip) handleTXT(question dns.Question, message *dns.Msg) {
fqdn := question.Name
normalizedFqdn := strings.ToLower(fqdn)
if hardcodedRecords[normalizedFqdn].TXT == nil {
xip.recordsMu.RLock()
records := xip.records[normalizedFqdn].TXT
xip.recordsMu.RUnlock()
if records == nil {
xip.answerWithAuthority(question, message)
return
}
for _, record := range hardcodedRecords[normalizedFqdn].TXT {
for _, record := range records {
message.Answer = append(message.Answer, &dns.TXT{
Hdr: dns.RR_Header{
Ttl: uint32((time.Minute * 5).Seconds()),
@@ -230,12 +248,15 @@ func (xip *Xip) handleTXT(question dns.Question, message *dns.Msg) {
func (xip *Xip) handleMX(question dns.Question, message *dns.Msg) {
fqdn := question.Name
normalizedFqdn := strings.ToLower(fqdn)
if hardcodedRecords[normalizedFqdn].MX == nil {
xip.recordsMu.RLock()
records := xip.records[normalizedFqdn].MX
xip.recordsMu.RUnlock()
if records == nil {
xip.answerWithAuthority(question, message)
return
}
for _, record := range hardcodedRecords[normalizedFqdn].MX {
for _, record := range records {
message.Answer = append(message.Answer, &dns.MX{
Hdr: dns.RR_Header{
Ttl: uint32((time.Minute * 5).Seconds()),
@@ -252,12 +273,15 @@ func (xip *Xip) handleMX(question dns.Question, message *dns.Msg) {
func (xip *Xip) handleCNAME(question dns.Question, message *dns.Msg) {
fqdn := question.Name
normalizedFqdn := strings.ToLower(fqdn)
if hardcodedRecords[normalizedFqdn].CNAME == nil {
xip.recordsMu.RLock()
records := xip.records[normalizedFqdn].CNAME
xip.recordsMu.RUnlock()
if records == nil {
xip.answerWithAuthority(question, message)
return
}
for _, record := range hardcodedRecords[normalizedFqdn].CNAME {
for _, record := range records {
message.Answer = append(message.Answer, &dns.CNAME{
Hdr: dns.RR_Header{
Ttl: uint32((time.Minute * 5).Seconds()),
@@ -273,7 +297,10 @@ func (xip *Xip) handleCNAME(question dns.Question, message *dns.Msg) {
func (xip *Xip) handleSRV(question dns.Question, message *dns.Msg) {
fqdn := question.Name
normalizedFqdn := strings.ToLower(fqdn)
if hardcodedRecords[normalizedFqdn].SRV == nil {
xip.recordsMu.RLock()
record := xip.records[normalizedFqdn].SRV
xip.recordsMu.RUnlock()
if record == nil {
xip.answerWithAuthority(question, message)
return
}
@@ -285,10 +312,10 @@ func (xip *Xip) handleSRV(question dns.Question, message *dns.Msg) {
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
},
Priority: hardcodedRecords[normalizedFqdn].SRV.Priority,
Weight: hardcodedRecords[normalizedFqdn].SRV.Weight,
Port: hardcodedRecords[normalizedFqdn].SRV.Port,
Target: hardcodedRecords[normalizedFqdn].SRV.Target,
Priority: record.Priority,
Weight: record.Weight,
Port: record.Port,
Target: record.Target,
})
}
@@ -398,7 +425,7 @@ func (xip *Xip) StartServer() {
err := xip.server.ListenAndServe()
defer xip.server.Shutdown()
if err != nil {
utils.Logger.Fatal().Err(err).Msg("Failed to start DNS server")
utils.Logger.Error().Err(err).Msg("Failed to start DNS server")
if strings.Contains(err.Error(), "fly-global-services: no such host") {
// we're not running on fly, bind to 0.0.0.0 instead
port := strings.Split(xip.server.Addr, ":")[1]
@@ -416,26 +443,27 @@ func (xip *Xip) StartServer() {
utils.Logger.Info().Str("dns_address", xip.server.Addr).Msg("Starting up DNS server")
}
func (xip *Xip) initHardcodedRecords() {
config := utils.GetConfig()
func (xip *Xip) initNameServers(nameServers []string) {
rootDomainARecords := []net.IP{}
for i, ns := range config.NameServers {
name := fmt.Sprintf("ns%d.%s.", i+1, config.Domain)
xip.recordsMu.Lock()
defer xip.recordsMu.Unlock()
for i, ns := range nameServers {
name := fmt.Sprintf("ns%d.%s.", i+1, xip.domain)
ip := net.ParseIP(ns)
rootDomainARecords = append(rootDomainARecords, ip)
entry := hardcodedRecords[name]
entry.A = append(hardcodedRecords[name].A, ip)
hardcodedRecords[name] = entry
entry := xip.records[name]
entry.A = append(xip.records[name].A, ip)
xip.records[name] = entry
xip.nameServers = append(xip.nameServers, name)
}
hardcodedRecords[fmt.Sprintf("%s.", config.Domain)] = hardcodedRecord{A: rootDomainARecords}
xip.records[fmt.Sprintf("%s.", xip.domain)] = hardcodedRecord{A: rootDomainARecords}
// will be filled in later when requesting certificates
hardcodedRecords[fmt.Sprintf("_acme-challenge.%s.", config.Domain)] = hardcodedRecord{TXT: []string{}}
xip.records[fmt.Sprintf("_acme-challenge.%s.", xip.domain)] = hardcodedRecord{TXT: []string{}}
}
func NewXip(opts ...Option) (xip *Xip) {
@@ -444,6 +472,7 @@ func NewXip(opts ...Option) (xip *Xip) {
domain: config.Domain,
email: config.Email,
dnsPort: config.DnsPort,
records: initialRecords(),
}
for _, opt := range opts {
@@ -451,7 +480,7 @@ func NewXip(opts ...Option) (xip *Xip) {
}
if len(xip.nameServers) == 0 {
xip.initHardcodedRecords()
xip.initNameServers(config.NameServers)
}
xip.server = dns.Server{