aboutsummaryrefslogtreecommitdiff
path: root/ssh_negotiate_term
blob: cedbd6ef0d322641346963a40da9818a0c64e746 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#!/usr/bin/python3
"""
Wrap ssh(1) commands, parse the options, and check whether the hostname
argument is either an IPv4/IPv6 address or matches the hostname pattern for
network equipment rather than servers.  If so, and we're using a TERM type
that's not well-supported on network gear, send a different TERM setting with
a more broadly-compatible TERM type.
"""
import argparse
import configparser
import ipaddress
import os
import re
import sys


class SSHArgumentParserError(Exception):
    """
    Raised by SSHArgumentParser in place of argparse's default behaviour when
    a command line cannot be parsed.
    """


class SSHArgumentParser(argparse.ArgumentParser):
    """
    An argparse.ArgumentParser whose parse failures surface as
    SSHArgumentParserError exceptions, so that callers can recover instead of
    having the process exit.
    """

    def error(self, message):
        """
        Raise SSHArgumentParserError carrying the error message argparse
        produced.
        """
        raise SSHArgumentParserError(message)


class SSHNegotiateTerm():
    """
    Wrap an ssh(1) invocation, optionally overriding the TERM it sends.

    The `os` and `sys` modules are injected through the constructor so that
    the environment, argument vector, standard streams, and exec call can all
    be substituted by the caller.

    Configuration is read from an INI file with three sections:

    - `ssh`: `path` is the real ssh(1) binary to execute (default
      `/usr/bin/ssh`).
    - `translations`: maps a local TERM value to the TERM value to send
      instead.
    - `patterns`: each value is a regular expression; a destination matching
      any of them gets the translated TERM.
    """

    def __init__(self, _os=os, _sys=sys, config_path=None):
        """
        Store the injected modules and load the configuration.

        `config_path` may start with `~`; if it is empty or None,
        `~/.config/ssh_negotiate_term/config` is used.  configparser silently
        ignores a file that cannot be opened, in which case only the built-in
        defaults apply.
        """
        self._os = _os
        self._sys = _sys

        # Seed the sections we index later, so that lookups can't fail just
        # because the config file is absent or incomplete.
        config = configparser.ConfigParser()
        config.read_dict({
            'patterns': {},
            'ssh': {'path': '/usr/bin/ssh'},
            'translations': {},
        })
        if not config_path:
            config_path = '~/.config/ssh_negotiate_term/config'
        config.read(self._os.path.expanduser(config_path))
        self.config = config

    def run(self):
        """
        Build the argument vector for the real ssh(1) command from what we
        were passed, and exec it.  If there's a compatible TERM string we
        should use for this call, add it with an ssh(1) option
        `-o SetEnv=TERM=...` ahead of the caller's own arguments.
        """
        ssh_path = self.config['ssh']['path']
        args = [ssh_path]
        term = self.get_compatible_term()
        if term:
            args.extend(['-o', f'SetEnv=TERM={term}'])
        args.extend(self._sys.argv[1:])
        self._os.execv(ssh_path, args)

    def get_hostname(self):
        """
        Parse the arguments after the program name as an ssh(1) command line
        and return the destination argument exactly as given (so it may
        include a `user@` prefix).  Return None if the command line can't be
        parsed.

        No remote command is declared to the parser, so a command line with
        anything after the destination is rejected, and None is returned for
        it as well.
        """
        # Disable argparse's automatic -h/--help: ssh(1) has no such option,
        # and leaving it on would make this wrapper print its own help and
        # exit rather than handing the arguments on to ssh(1).
        parser = SSHArgumentParser(add_help=False)
        # Boolean flags.
        for letter in '46AaCfGgKkMNnqsTtVvXxYy':
            parser.add_argument('-' + letter, action='store_true')
        # Options that require an argument.
        for letter in 'BbcDEeFIiJLlmOopQRSWw':
            parser.add_argument('-' + letter)
        # The destination, which is always required.
        parser.add_argument('hostname')

        # Parse the injected argument vector (the same one run() forwards),
        # not the global sys.argv.  On anything argparse can't make sense of,
        # the parser subclass raises our custom exception, and we report that
        # we can't tell what the hostname is.
        try:
            args = parser.parse_args(self._sys.argv[1:])
        except SSHArgumentParserError:
            return None
        return args.hostname

    def get_compatible_term(self):
        """
        If there's a better TERM string to use for this destination, return
        it; otherwise, return None, meaning that we shouldn't send a
        different TERM for the actual ssh(1) command run.

        A translated TERM is returned only if all of these hold: the current
        TERM has an entry in the `translations` section; stdin, stdout, and
        stderr are all terminals; the destination can be determined; and the
        destination is either an IP address or matches one of the configured
        patterns.
        """

        # Require that there's a translation for this TERM string to use.  A
        # missing TERM in the environment is treated the same way.
        try:
            term = self.config['translations'][self._os.environ['TERM']]
        except KeyError:
            return None

        # Require stdin, stdout, and stderr all point to a terminal.  Kept as
        # an explicit loop by preference, for readability.
        for descriptor in [
                self._sys.stdin,
                self._sys.stdout,
                self._sys.stderr]:
            if not descriptor.isatty():
                return None

        # Require that we can work out the destination from the command
        # line.
        hostname = self.get_hostname()
        if not hostname:
            return None

        # If the destination is an IPv4 or IPv6 address, downgrade the TERM
        # for compatibility.  ssh(1) accepts `user@host`, so drop any login
        # name first; otherwise `admin@192.0.2.1` would never parse as an
        # address.
        address = hostname.rpartition('@')[2]
        try:
            ipaddress.ip_address(address)
        except ValueError:
            pass
        else:
            return term

        # If the destination matches any configured pattern, downgrade the
        # TERM for compatibility.  Patterns are matched against the
        # destination as it was typed, including any `user@` prefix.
        for regex in self.config['patterns'].values():
            if re.search(regex, hostname):
                return term

        # Default to keeping the defined terminal (i.e., don't downgrade).
        return None


if __name__ == '__main__':
    # Script entry point: load the default configuration and hand over to
    # the real ssh(1).
    SSHNegotiateTerm().run()