diff --git a/acmetxt.go b/acmetxt.go index 63454a6..e7ba7ea 100644 --- a/acmetxt.go +++ b/acmetxt.go @@ -12,6 +12,7 @@ import ( type ACMETxt struct { Username uuid.UUID Password string + Direct bool ACMETxtPost AllowFrom cidrslice } diff --git a/api.go b/api.go index 864256c..beb16c4 100644 --- a/api.go +++ b/api.go @@ -82,15 +82,15 @@ func webUpdatePost(w http.ResponseWriter, r *http.Request, _ httprouter.Params) // NOTE: An invalid subdomain should not happen - the auth handler should // reject POSTs with an invalid subdomain before this handler. Reject any // invalid subdomains anyway as a matter of caution. - if !validSubdomain(a.Subdomain) { + if !a.Direct && !validSubdomain(a.Subdomain) { log.WithFields(log.Fields{"error": "subdomain", "subdomain": a.Subdomain, "txt": a.Value}).Debug("Bad update data") updStatus = http.StatusBadRequest upd = jsonError("bad_subdomain") - } else if !validTXT(a.Value) { + } else if !a.Direct && !validTXT(a.Value) { log.WithFields(log.Fields{"error": "txt", "subdomain": a.Subdomain, "txt": a.Value}).Debug("Bad update data") updStatus = http.StatusBadRequest upd = jsonError("bad_txt") - } else if validSubdomain(a.Subdomain) && validTXT(a.Value) { + } else if a.Direct || (validSubdomain(a.Subdomain) && validTXT(a.Value)) { err := DB.Update(a.ACMETxtPost) if err != nil { log.WithFields(log.Fields{"error": err.Error()}).Debug("Error while trying to update record") diff --git a/auth.go b/auth.go index c09f8b4..c91214d 100644 --- a/auth.go +++ b/auth.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "net/http" + "os" "github.com/julienschmidt/httprouter" log "github.com/sirupsen/logrus" @@ -20,6 +21,18 @@ const ACMETxtKey key = 0 func Auth(update httprouter.Handle) httprouter.Handle { return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) { postData := ACMETxt{} + directKey := r.Header.Get("X-Direct-Key") + if directKey != "" && directKey == os.Getenv("ACME_DNS_DIRECT_STATIC_KEY") { + dec := json.NewDecoder(r.Body) + err := dec.Decode(&postData) + if err != nil { + log.WithFields(log.Fields{"error": "json_error", "string": err.Error()}).Error("Decode error") + } + postData.Direct = true + ctx := context.WithValue(r.Context(), ACMETxtKey, postData) + update(w, r.WithContext(ctx), p) + return + } userOK := false user, err := getUserFromRequest(r) if err == nil { diff --git a/db.go b/db.go index 3534728..4a389ac 100644 --- a/db.go +++ b/db.go @@ -35,7 +35,7 @@ var userTable = ` var txtTable = ` CREATE TABLE IF NOT EXISTS txt( - Subdomain TEXT NOT NULL, + Subdomain TEXT NOT NULL PRIMARY KEY, Value TEXT NOT NULL DEFAULT '', LastUpdate INT );` @@ -43,7 +43,7 @@ var txtTable = ` var txtTablePG = ` CREATE TABLE IF NOT EXISTS txt( rowid SERIAL, - Subdomain TEXT NOT NULL, + Subdomain TEXT NOT NULL PRIMARY KEY, Value TEXT NOT NULL DEFAULT '', LastUpdate INT );` @@ -250,7 +250,6 @@ func (d *acmedb) GetByUsername(u uuid.UUID) (ACMETxt, error) { func (d *acmedb) GetTXTForDomain(domain string) ([]string, error) { d.Lock() defer d.Unlock() - domain = sanitizeString(domain) var txts []string getSQL := ` SELECT Value FROM txt WHERE Subdomain=$1 LIMIT 2 @@ -289,9 +288,11 @@ func (d *acmedb) Update(a ACMETxtPost) error { timenow := time.Now().Unix() updSQL := ` - UPDATE txt SET Value=$1, LastUpdate=$2 - WHERE rowid=( - SELECT rowid FROM txt WHERE Subdomain=$3 ORDER BY LastUpdate LIMIT 1) + INSERT INTO txt (Value, LastUpdate, Subdomain) + VALUES ($1, $2, $3) + ON CONFLICT (Subdomain) DO UPDATE SET + Value = excluded.Value, + LastUpdate = excluded.LastUpdate; ` if Config.Database.Engine == "sqlite3" { updSQL = getSQLiteStmt(updSQL) diff --git a/db_test.go b/db_test.go index beca9c1..b775cf4 100644 --- a/db_test.go +++ b/db_test.go @@ -251,19 +251,12 @@ func TestGetTXTForDomain(t *testing.T) { t.Errorf("No rows returned for GetTXTForDomain [%s]", reg.Subdomain) } - var val1found = false var val2found = false for _, v := range regDomainSlice { - if v == txtval1 { - val1found = true - } if v == txtval2 { val2found = true } } - if !val1found { - t.Errorf("No TXT value found for val1") - } if !val2found { t.Errorf("No TXT value found for val2") } diff --git a/dns.go b/dns.go index 9a3b06b..6e8b3d8 100644 --- a/dns.go +++ b/dns.go @@ -195,16 +195,12 @@ func (d *DNSServer) answer(q dns.Question) ([]dns.RR, int, bool, error) { var err error var txtRRs []dns.RR var authoritative = d.isAuthoritative(q) - if !d.isOwnChallenge(q.Name) && !d.answeringForDomain(q.Name) { + if !d.answeringForDomain(q.Name) { rcode = dns.RcodeNameError } r, _ := d.getRecord(q) if q.Qtype == dns.TypeTXT { - if d.isOwnChallenge(q.Name) { - txtRRs, err = d.answerOwnChallenge(q) - } else { - txtRRs, err = d.answerTXT(q) - } + txtRRs, err = d.answerTXT(q) if err == nil { r = append(r, txtRRs...) } @@ -219,7 +215,7 @@ func (d *DNSServer) answer(q dns.Question) ([]dns.RR, int, bool, error) { func (d *DNSServer) answerTXT(q dns.Question) ([]dns.RR, error) { var ra []dns.RR - subdomain := sanitizeDomainQuestion(q.Name) + subdomain, _ := strings.CutSuffix(sanitizeDomainQuestion(q.Name), "."+d.Domain) atxt, err := d.DB.GetTXTForDomain(subdomain) if err != nil { log.WithFields(log.Fields{"error": err.Error()}).Debug("Error while trying to get record") diff --git a/util.go b/util.go index 163683d..007907d 100644 --- a/util.go +++ b/util.go @@ -83,6 +83,10 @@ func generatePassword(length int) string { func sanitizeDomainQuestion(d string) string { dom := strings.ToLower(d) + // HACK + if strings.HasPrefix(dom, "_acme-challenge") { + return dom + } firstDot := strings.Index(d, ".") if firstDot > 0 { dom = dom[0:firstDot]