From b5e3630e4463b7a14b2dd809d6fc1568900ccbbd Mon Sep 17 00:00:00 2001 From: m5r Date: Sun, 27 Oct 2024 23:15:41 +0100 Subject: [PATCH] make the https server aware of freshly renewed certificates without restarting it --- http/server.go | 39 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 36 insertions(+), 3 deletions(-) diff --git a/http/server.go b/http/server.go index 5bf56ff..b4c3242 100644 --- a/http/server.go +++ b/http/server.go @@ -2,6 +2,7 @@ package http import ( "context" + "crypto/tls" "fmt" "net" "net/http" @@ -141,16 +142,48 @@ func redirectHttpToHttps() { go httpServer.ListenAndServe() } +type CertificateReloader struct { + CertificateFilePath string + KeyFilePath string + certificate *tls.Certificate + lastUpdatedAt time.Time +} + +func (cr *CertificateReloader) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) { + stat, err := os.Stat(cr.KeyFilePath) + if err != nil { + return nil, fmt.Errorf("failed checking key file modification time: %w", err) + } + + if cr.certificate == nil || stat.ModTime().After(cr.lastUpdatedAt) { + pair, err := tls.LoadX509KeyPair(cr.CertificateFilePath, cr.KeyFilePath) + if err != nil { + return nil, fmt.Errorf("failed loading tls key pair: %w", err) + } + + cr.certificate = &pair + cr.lastUpdatedAt = stat.ModTime() + } + + return cr.certificate, nil +} + +var certificateReloader = &CertificateReloader{ + CertificateFilePath: "./.lego/certs/root/server.pem", + KeyFilePath: "./.lego/certs/root/server.key", +} + func serveHttps() { config := utils.GetConfig() mux := newHttpMux() httpsServer := &http.Server{ - Addr: fmt.Sprintf(":%d", config.HttpsPort), - Handler: mux, + Addr: fmt.Sprintf(":%d", config.HttpsPort), + Handler: mux, + TLSConfig: &tls.Config{GetCertificate: certificateReloader.GetCertificate}, } utils.Logger.Info().Str("https_address", httpsServer.Addr).Msg("Starting up HTTPS server") go func() { - err := httpsServer.ListenAndServeTLS("./.lego/certs/root/server.pem", "./.lego/certs/root/server.key") + err := httpsServer.ListenAndServeTLS("", "") if err != http.ErrServerClosed { utils.Logger.Fatal().Err(err).Msg("Unexpected error received from HTTPS server") }