// Package ssh_config provides tools for manipulating SSH config files. // // Importantly, this parser attempts to preserve comments in a given file, so // you can manipulate a `ssh_config` file from a program, if your heart desires. // // The Get() and GetStrict() functions will attempt to read values from // $HOME/.ssh/config, falling back to /etc/ssh/ssh_config. The first argument is // the host name to match on ("example.com"), and the second argument is the key // you want to retrieve ("Port"). The keywords are case insensitive. // // port := ssh_config.Get("myhost", "Port") // // You can also manipulate an SSH config file and then print it or write it back // to disk. // // f, _ := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "config")) // cfg, _ := ssh_config.Decode(f) // for _, host := range cfg.Hosts { // fmt.Println("patterns:", host.Patterns) // for _, node := range host.Nodes { // fmt.Println(node.String()) // } // } // // // Write the cfg back to disk: // fmt.Println(cfg.String()) // // BUG: the Match directive is currently unsupported; parsing a config with // a Match directive will trigger an error. package ssh_config import ( "bytes" "errors" "fmt" "io" "io/ioutil" "os" osuser "os/user" "path/filepath" "regexp" "runtime" "strings" "sync" ) const version = "1.0" var _ = version type configFinder func() string // UserSettings checks ~/.ssh and /etc/ssh for configuration files. The config // files are parsed and cached the first time Get() or GetStrict() is called. type UserSettings struct { IgnoreErrors bool systemConfig *Config systemConfigFinder configFinder userConfig *Config userConfigFinder configFinder loadConfigs sync.Once onceErr error } func homedir() string { user, err := osuser.Current() if err == nil { return user.HomeDir } else { return os.Getenv("HOME") } } func userConfigFinder() string { return filepath.Join(homedir(), ".ssh", "config") } // DefaultUserSettings is the default UserSettings and is used by Get and // GetStrict. It checks both $HOME/.ssh/config and /etc/ssh/ssh_config for keys, // and it will return parse errors (if any) instead of swallowing them. var DefaultUserSettings = &UserSettings{ IgnoreErrors: false, systemConfigFinder: systemConfigFinder, userConfigFinder: userConfigFinder, } func systemConfigFinder() string { return filepath.Join("/", "etc", "ssh", "ssh_config") } func findVal(c *Config, alias, key string) (string, error) { if c == nil { return "", nil } val, err := c.Get(alias, key) if err != nil || val == "" { return "", err } if err := validate(key, val); err != nil { return "", err } return val, nil } // Get finds the first value for key within a declaration that matches the // alias. Get returns the empty string if no value was found, or if IgnoreErrors // is false and we could not parse the configuration file. Use GetStrict to // disambiguate the latter cases. // // The match for key is case insensitive. // // Get is a wrapper around DefaultUserSettings.Get. func Get(alias, key string) string { return DefaultUserSettings.Get(alias, key) } // GetStrict finds the first value for key within a declaration that matches the // alias. If key has a default value and no matching configuration is found, the // default will be returned. For more information on default values and the way // patterns are matched, see the manpage for ssh_config. // // error will be non-nil if and only if a user's configuration file or the // system configuration file could not be parsed, and u.IgnoreErrors is false. // // GetStrict is a wrapper around DefaultUserSettings.GetStrict. func GetStrict(alias, key string) (string, error) { return DefaultUserSettings.GetStrict(alias, key) } // Get finds the first value for key within a declaration that matches the // alias. Get returns the empty string if no value was found, or if IgnoreErrors // is false and we could not parse the configuration file. Use GetStrict to // disambiguate the latter cases. // // The match for key is case insensitive. func (u *UserSettings) Get(alias, key string) string { val, err := u.GetStrict(alias, key) if err != nil { return "" } return val } // GetStrict finds the first value for key within a declaration that matches the // alias. If key has a default value and no matching configuration is found, the // default will be returned. For more information on default values and the way // patterns are matched, see the manpage for ssh_config. // // error will be non-nil if and only if a user's configuration file or the // system configuration file could not be parsed, and u.IgnoreErrors is false. func (u *UserSettings) GetStrict(alias, key string) (string, error) { u.loadConfigs.Do(func() { // can't parse user file, that's ok. var filename string if u.userConfigFinder == nil { filename = userConfigFinder() } else { filename = u.userConfigFinder() } var err error u.userConfig, err = parseFile(filename) //lint:ignore S1002 I prefer it this way if err != nil && os.IsNotExist(err) == false { u.onceErr = err return } if u.systemConfigFinder == nil { filename = systemConfigFinder() } else { filename = u.systemConfigFinder() } u.systemConfig, err = parseFile(filename) //lint:ignore S1002 I prefer it this way if err != nil && os.IsNotExist(err) == false { u.onceErr = err return } }) //lint:ignore S1002 I prefer it this way if u.onceErr != nil && u.IgnoreErrors == false { return "", u.onceErr } val, err := findVal(u.userConfig, alias, key) if err != nil || val != "" { return val, err } val2, err2 := findVal(u.systemConfig, alias, key) if err2 != nil || val2 != "" { return val2, err2 } return Default(key), nil } func parseFile(filename string) (*Config, error) { return parseWithDepth(filename, 0) } func parseWithDepth(filename string, depth uint8) (*Config, error) { b, err := ioutil.ReadFile(filename) if err != nil { return nil, err } return decodeBytes(b, isSystem(filename), depth) } func isSystem(filename string) bool { // TODO: not sure this is the best way to detect a system repo return strings.HasPrefix(filepath.Clean(filename), "/etc/ssh") } // Decode reads r into a Config, or returns an error if r could not be parsed as // an SSH config file. func Decode(r io.Reader) (*Config, error) { b, err := ioutil.ReadAll(r) if err != nil { return nil, err } return decodeBytes(b, false, 0) } func decodeBytes(b []byte, system bool, depth uint8) (c *Config, err error) { defer func() { if r := recover(); r != nil { if _, ok := r.(runtime.Error); ok { panic(r) } if e, ok := r.(error); ok && e == ErrDepthExceeded { err = e return } err = errors.New(r.(string)) } }() c = parseSSH(lexSSH(b), system, depth) return c, err } // Config represents an SSH config file. type Config struct { // A list of hosts to match against. The file begins with an implicit // "Host *" declaration matching all hosts. Hosts []*Host depth uint8 position Position } // Get finds the first value in the configuration that matches the alias and // contains key. Get returns the empty string if no value was found, or if the // Config contains an invalid conditional Include value. // // The match for key is case insensitive. func (c *Config) Get(alias, key string) (string, error) { lowerKey := strings.ToLower(key) for _, host := range c.Hosts { if !host.Matches(alias) { continue } for _, node := range host.Nodes { switch t := node.(type) { case *Empty: continue case *KV: // "keys are case insensitive" per the spec lkey := strings.ToLower(t.Key) if lkey == "match" { panic("can't handle Match directives") } if lkey == lowerKey { return t.Value, nil } case *Include: val := t.Get(alias, key) if val != "" { return val, nil } default: return "", fmt.Errorf("unknown Node type %v", t) } } } return "", nil } // String returns a string representation of the Config file. func (c Config) String() string { return marshal(c).String() } func (c Config) MarshalText() ([]byte, error) { return marshal(c).Bytes(), nil } func marshal(c Config) *bytes.Buffer { var buf bytes.Buffer for i := range c.Hosts { buf.WriteString(c.Hosts[i].String()) } return &buf } // Pattern is a pattern in a Host declaration. Patterns are read-only values; // create a new one with NewPattern(). type Pattern struct { str string // Its appearance in the file, not the value that gets compiled. regex *regexp.Regexp not bool // True if this is a negated match } // String prints the string representation of the pattern. func (p Pattern) String() string { return p.str } // Copied from regexp.go with * and ? removed. var specialBytes = []byte(`\.+()|[]{}^$`) func special(b byte) bool { return bytes.IndexByte(specialBytes, b) >= 0 } // NewPattern creates a new Pattern for matching hosts. NewPattern("*") creates // a Pattern that matches all hosts. // // From the manpage, a pattern consists of zero or more non-whitespace // characters, `*' (a wildcard that matches zero or more characters), or `?' (a // wildcard that matches exactly one character). For example, to specify a set // of declarations for any host in the ".co.uk" set of domains, the following // pattern could be used: // // Host *.co.uk // // The following pattern would match any host in the 192.168.0.[0-9] network range: // // Host 192.168.0.? func NewPattern(s string) (*Pattern, error) { if s == "" { return nil, errors.New("ssh_config: empty pattern") } negated := false if s[0] == '!' { negated = true s = s[1:] } var buf bytes.Buffer buf.WriteByte('^') for i := 0; i < len(s); i++ { // A byte loop is correct because all metacharacters are ASCII. switch b := s[i]; b { case '*': buf.WriteString(".*") case '?': buf.WriteString(".?") default: // borrowing from QuoteMeta here. if special(b) { buf.WriteByte('\\') } buf.WriteByte(b) } } buf.WriteByte('$') r, err := regexp.Compile(buf.String()) if err != nil { return nil, err } return &Pattern{str: s, regex: r, not: negated}, nil } // Host describes a Host directive and the keywords that follow it. type Host struct { // A list of host patterns that should match this host. Patterns []*Pattern // A Node is either a key/value pair or a comment line. Nodes []Node // EOLComment is the comment (if any) terminating the Host line. EOLComment string hasEquals bool leadingSpace int // TODO: handle spaces vs tabs here. // The file starts with an implicit "Host *" declaration. implicit bool } // Matches returns true if the Host matches for the given alias. For // a description of the rules that provide a match, see the manpage for // ssh_config. func (h *Host) Matches(alias string) bool { found := false for i := range h.Patterns { if h.Patterns[i].regex.MatchString(alias) { if h.Patterns[i].not { // Negated match. "A pattern entry may be negated by prefixing // it with an exclamation mark (`!'). If a negated entry is // matched, then the Host entry is ignored, regardless of // whether any other patterns on the line match. Negated matches // are therefore useful to provide exceptions for wildcard // matches." return false } found = true } } return found } // String prints h as it would appear in a config file. Minor tweaks may be // present in the whitespace in the printed file. func (h *Host) String() string { var buf bytes.Buffer //lint:ignore S1002 I prefer to write it this way if h.implicit == false { buf.WriteString(strings.Repeat(" ", int(h.leadingSpace))) buf.WriteString("Host") if h.hasEquals { buf.WriteString(" = ") } else { buf.WriteString(" ") } for i, pat := range h.Patterns { buf.WriteString(pat.String()) if i < len(h.Patterns)-1 { buf.WriteString(" ") } } if h.EOLComment != "" { buf.WriteString(" #") buf.WriteString(h.EOLComment) } buf.WriteByte('\n') } for i := range h.Nodes { buf.WriteString(h.Nodes[i].String()) buf.WriteByte('\n') } return buf.String() } // Node represents a line in a Config. type Node interface { Pos() Position String() string } // KV is a line in the config file that contains a key, a value, and possibly // a comment. type KV struct { Key string Value string Comment string hasEquals bool leadingSpace int // Space before the key. TODO handle spaces vs tabs. position Position } // Pos returns k's Position. func (k *KV) Pos() Position { return k.position } // String prints k as it was parsed in the config file. There may be slight // changes to the whitespace between values. func (k *KV) String() string { if k == nil { return "" } equals := " " if k.hasEquals { equals = " = " } line := fmt.Sprintf("%s%s%s%s", strings.Repeat(" ", int(k.leadingSpace)), k.Key, equals, k.Value) if k.Comment != "" { line += " #" + k.Comment } return line } // Empty is a line in the config file that contains only whitespace or comments. type Empty struct { Comment string leadingSpace int // TODO handle spaces vs tabs. position Position } // Pos returns e's Position. func (e *Empty) Pos() Position { return e.position } // String prints e as it was parsed in the config file. func (e *Empty) String() string { if e == nil { return "" } if e.Comment == "" { return "" } return fmt.Sprintf("%s#%s", strings.Repeat(" ", int(e.leadingSpace)), e.Comment) } // Include holds the result of an Include directive, including the config files // that have been parsed as part of that directive. At most 5 levels of Include // statements will be parsed. type Include struct { // Comment is the contents of any comment at the end of the Include // statement. Comment string // an include directive can include several different files, and wildcards directives []string mu sync.Mutex // 1:1 mapping between matches and keys in files array; matches preserves // ordering matches []string // actual filenames are listed here files map[string]*Config leadingSpace int position Position depth uint8 hasEquals bool } const maxRecurseDepth = 5 // ErrDepthExceeded is returned if too many Include directives are parsed. // Usually this indicates a recursive loop (an Include directive pointing to the // file it contains). var ErrDepthExceeded = errors.New("ssh_config: max recurse depth exceeded") func removeDups(arr []string) []string { // Use map to record duplicates as we find them. encountered := make(map[string]bool, len(arr)) result := make([]string, 0) for v := range arr { //lint:ignore S1002 I prefer it this way if encountered[arr[v]] == false { encountered[arr[v]] = true result = append(result, arr[v]) } } return result } // NewInclude creates a new Include with a list of file globs to include. // Configuration files are parsed greedily (e.g. as soon as this function runs). // Any error encountered while parsing nested configuration files will be // returned. func NewInclude(directives []string, hasEquals bool, pos Position, comment string, system bool, depth uint8) (*Include, error) { if depth > maxRecurseDepth { return nil, ErrDepthExceeded } inc := &Include{ Comment: comment, directives: directives, files: make(map[string]*Config), position: pos, leadingSpace: pos.Col - 1, depth: depth, hasEquals: hasEquals, } // no need for inc.mu.Lock() since nothing else can access this inc matches := make([]string, 0) for i := range directives { var path string if filepath.IsAbs(directives[i]) { path = directives[i] } else if system { path = filepath.Join("/etc/ssh", directives[i]) } else { path = filepath.Join(homedir(), ".ssh", directives[i]) } theseMatches, err := filepath.Glob(path) if err != nil { return nil, err } matches = append(matches, theseMatches...) } matches = removeDups(matches) inc.matches = matches for i := range matches { config, err := parseWithDepth(matches[i], depth) if err != nil { return nil, err } inc.files[matches[i]] = config } return inc, nil } // Pos returns the position of the Include directive in the larger file. func (i *Include) Pos() Position { return i.position } // Get finds the first value in the Include statement matching the alias and the // given key. func (inc *Include) Get(alias, key string) string { inc.mu.Lock() defer inc.mu.Unlock() // TODO: we search files in any order which is not correct for i := range inc.matches { cfg := inc.files[inc.matches[i]] if cfg == nil { panic("nil cfg") } val, err := cfg.Get(alias, key) if err == nil && val != "" { return val } } return "" } // String prints out a string representation of this Include directive. Note // included Config files are not printed as part of this representation. func (inc *Include) String() string { equals := " " if inc.hasEquals { equals = " = " } line := fmt.Sprintf("%sInclude%s%s", strings.Repeat(" ", int(inc.leadingSpace)), equals, strings.Join(inc.directives, " ")) if inc.Comment != "" { line += " #" + inc.Comment } return line } var matchAll *Pattern func init() { var err error matchAll, err = NewPattern("*") if err != nil { panic(err) } } func newConfig() *Config { return &Config{ Hosts: []*Host{ &Host{ implicit: true, Patterns: []*Pattern{matchAll}, Nodes: make([]Node, 0), }, }, depth: 0, } }