diff --git a/breakglass.go b/breakglass.go index 92e8d98..0b39cbc 100644 --- a/breakglass.go +++ b/breakglass.go @@ -8,12 +8,14 @@ import ( "crypto/rand" "crypto/rsa" "crypto/x509" + "encoding/json" "encoding/pem" "flag" "fmt" "io/ioutil" "log" "net" + "net/http" "os" "strings" "syscall" @@ -113,6 +115,79 @@ func createHostKey(path string) (ssh.Signer, error) { return ssh.NewSignerFromKey(key) } +func buildTimestamp() (string, error) { + var statusReply struct { + BuildTimestamp string `json:"BuildTimestamp"` + } + pw, err := os.ReadFile("/etc/gokr-pw.txt") + if err != nil { + return "", err + } + req, err := http.NewRequest("GET", "http://gokrazy:"+strings.TrimSpace(string(pw))+"@localhost/", nil) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + if got, want := resp.StatusCode, http.StatusOK; got != want { + b, _ := ioutil.ReadAll(resp.Body) + return "", fmt.Errorf("unexpected HTTP status code: got %v, want %v (body: %s)", resp.Status, want, strings.TrimSpace(string(b))) + } + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return "", err + } + if err := json.Unmarshal(b, &statusReply); err != nil { + return "", err + } + return statusReply.BuildTimestamp, nil +} + +var motd string + +func initMOTD() error { + if !*enableBanner { + return nil + } + + hostname, err := os.Hostname() + if err != nil { + log.Printf("os.Hostname(): %v", err) + hostname = "gokrazy" + } + const maxSpace = " " + if len(hostname) > len(maxSpace) { + hostname = hostname[:len(maxSpace)] + } + hostname += `"` + if padding := len(maxSpace) - len(hostname); padding > 0 { + hostname += strings.Repeat(" ", padding) + } + + buildTimestamp, err := buildTimestamp() + if err != nil { + return err + } + + motd = fmt.Sprintf(` __ + .-----.-----| |--.----.---.-.-----.--.--. + | _ | _ | <| _| _ |-- __| | | + |___ |_____|__|__|__| |___._|_____|___ | + |_____| host: "%s |_____| + model: %s + build: %s +`, + hostname, + gokrazy.Model(), + buildTimestamp) + return nil +} + func main() { flag.Parse() log.SetFlags(log.LstdFlags | log.Lshortfile) @@ -127,6 +202,10 @@ func main() { log.Fatal(err) } + if err := initMOTD(); err != nil { + log.Print(err) + } + config := &ssh.ServerConfig{ PublicKeyCallback: func(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { if authorizedKeys[string(pubKey.Marshal())] { @@ -135,20 +214,6 @@ func main() { } return nil, fmt.Errorf("public key not found in %s", *authorizedKeysPath) }, - BannerCallback: func(conn ssh.ConnMetadata) string { - if !*enableBanner { - return "" - } - bannerMessage := fmt.Sprintf("#\n# Welcome to gokrazy, %s!\n", conn.User()) - bannerInfo := fmt.Sprintf("# This installation is running on a %q!\n#\n", gokrazy.Model()) - maxChars := len(bannerInfo) - if maxChars < len(bannerMessage) { - maxChars = len(bannerMessage) - } - border := strings.Repeat("#", maxChars) + "\n" - bannerMessage = border + bannerMessage + bannerInfo + border - return bannerMessage - }, } signer, err := loadHostKey(*hostKeyPath) diff --git a/ssh.go b/ssh.go index 6913c97..92ca4c9 100644 --- a/ssh.go +++ b/ssh.go @@ -323,6 +323,9 @@ func (s *session) request(ctx context.Context, req *ssh.Request) error { case "shell": req.Payload = []byte("\x00\x00\x00\x02sh") + if motd != "" { + fmt.Fprintf(s.channel.Stderr(), "%s\r\n", strings.ReplaceAll(motd, "\n", "\r\n")) + } fallthrough case "exec":