turn banner (printed before auth) into MOTD (printed after login)

This means the message will be printed only once when using the breakglass
command line tool (which first copies over a tarball, then logs in).

Also switch to fancy ASCII art while we’re at it :)
This commit is contained in:
Michael Stapelberg 2022-07-09 18:38:32 +02:00
parent c21964dfd8
commit c857ec6218
2 changed files with 82 additions and 14 deletions

View File

@ -8,12 +8,14 @@ import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"flag"
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
"os"
"strings"
"syscall"
@ -113,6 +115,79 @@ func createHostKey(path string) (ssh.Signer, error) {
return ssh.NewSignerFromKey(key)
}
func buildTimestamp() (string, error) {
var statusReply struct {
BuildTimestamp string `json:"BuildTimestamp"`
}
pw, err := os.ReadFile("/etc/gokr-pw.txt")
if err != nil {
return "", err
}
req, err := http.NewRequest("GET", "http://gokrazy:"+strings.TrimSpace(string(pw))+"@localhost/", nil)
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
if got, want := resp.StatusCode, http.StatusOK; got != want {
b, _ := ioutil.ReadAll(resp.Body)
return "", fmt.Errorf("unexpected HTTP status code: got %v, want %v (body: %s)", resp.Status, want, strings.TrimSpace(string(b)))
}
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return "", err
}
if err := json.Unmarshal(b, &statusReply); err != nil {
return "", err
}
return statusReply.BuildTimestamp, nil
}
var motd string
func initMOTD() error {
if !*enableBanner {
return nil
}
hostname, err := os.Hostname()
if err != nil {
log.Printf("os.Hostname(): %v", err)
hostname = "gokrazy"
}
const maxSpace = " "
if len(hostname) > len(maxSpace) {
hostname = hostname[:len(maxSpace)]
}
hostname += `"`
if padding := len(maxSpace) - len(hostname); padding > 0 {
hostname += strings.Repeat(" ", padding)
}
buildTimestamp, err := buildTimestamp()
if err != nil {
return err
}
motd = fmt.Sprintf(` __
.-----.-----| |--.----.---.-.-----.--.--.
| _ | _ | <| _| _ |-- __| | |
|___ |_____|__|__|__| |___._|_____|___ |
|_____| host: "%s |_____|
model: %s
build: %s
`,
hostname,
gokrazy.Model(),
buildTimestamp)
return nil
}
func main() {
flag.Parse()
log.SetFlags(log.LstdFlags | log.Lshortfile)
@ -127,6 +202,10 @@ func main() {
log.Fatal(err)
}
if err := initMOTD(); err != nil {
log.Print(err)
}
config := &ssh.ServerConfig{
PublicKeyCallback: func(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
if authorizedKeys[string(pubKey.Marshal())] {
@ -135,20 +214,6 @@ func main() {
}
return nil, fmt.Errorf("public key not found in %s", *authorizedKeysPath)
},
BannerCallback: func(conn ssh.ConnMetadata) string {
if !*enableBanner {
return ""
}
bannerMessage := fmt.Sprintf("#\n# Welcome to gokrazy, %s!\n", conn.User())
bannerInfo := fmt.Sprintf("# This installation is running on a %q!\n#\n", gokrazy.Model())
maxChars := len(bannerInfo)
if maxChars < len(bannerMessage) {
maxChars = len(bannerMessage)
}
border := strings.Repeat("#", maxChars) + "\n"
bannerMessage = border + bannerMessage + bannerInfo + border
return bannerMessage
},
}
signer, err := loadHostKey(*hostKeyPath)

3
ssh.go
View File

@ -323,6 +323,9 @@ func (s *session) request(ctx context.Context, req *ssh.Request) error {
case "shell":
req.Payload = []byte("\x00\x00\x00\x02sh")
if motd != "" {
fmt.Fprintf(s.channel.Stderr(), "%s\r\n", strings.ReplaceAll(motd, "\n", "\r\n"))
}
fallthrough
case "exec":