diff options
-rw-r--r-- | config.sample | 5 | ||||
-rwxr-xr-x | ssh_negotiate_term | 151 |
2 files changed, 156 insertions, 0 deletions
diff --git a/config.sample b/config.sample new file mode 100644 index 0000000..9566967 --- /dev/null +++ b/config.sample @@ -0,0 +1,5 @@ +[patterns] +# This will match the exact hostname "examplehost" +example_host = ^examplehost$ +# This will match any host matching a pattern like abc-example-def-123a +network_equipment = ^[a-z]{3}-[a-z]{2,}-[a-z]{2,}-[0-9] diff --git a/ssh_negotiate_term b/ssh_negotiate_term new file mode 100755 index 0000000..cedbd6e --- /dev/null +++ b/ssh_negotiate_term @@ -0,0 +1,151 @@ +#!/usr/bin/python3 +""" +Wrap ssh(1) commands, parse the options, and check whether the hostname +argument is either an IPv4/IPv6 address or matches the hostname pattern for +network equipment rather than servers. If so, and we're using a TERM type +that's not well-supported on network gear, send a different TERM setting with +a more broadly-compatible TERM type. +""" +import argparse +import configparser +import ipaddress +import os +import re +import sys + + +class SSHArgumentParserError(Exception): + """ + Exception for SSHArgumentParser to throw on error. + """ + + +class SSHArgumentParser(argparse.ArgumentParser): + """ + Subclass argparse.ArgumentParser just so that errors raise exceptions + rather than exiting. + """ + + def error(self, message): + raise SSHArgumentParserError(message) + + +class SSHNegotiateTerm(): + """ + Using a class, to encapsulate a fair bit of system state that gets injected + into this. + """ + + def __init__(self, _os=os, _sys=sys, config_path=None): + self._os = _os + self._sys = _sys + + config = configparser.ConfigParser() + config['patterns'] = {} + config['ssh'] = {} + config['ssh']['path'] = '/usr/bin/ssh' + config['translations'] = {} + if not config_path: + config_path = self._os.path.expanduser( + '~/.config/ssh_negotiate_term/config') + config.read(self._os.path.expanduser(config_path)) + self.config = config + + def run(self): + """ + Build argument vector for the real SSH command from what we were passed; if + there's a compatible TERM string we should use for this call, add it with an + ssh(1) option -o string. + """ + args = [self.config['ssh']['path']] + term = self.get_compatible_term() + if term: + args.extend(['-o', f'SetEnv=TERM={term}']) + args.extend(self._sys.argv[1:]) + self._os.execv(self.config['ssh']['path'], args) + + def get_hostname(self): + """ + Given an argument vector for an ssh(1) command---including the `ssh` + command itself---iterate through it until we find what should be the + hostname, taking account of argument options. Return 'None' on fail. + """ + + # Set up a parser object, a subclass of the stock argparse one. + parser = SSHArgumentParser() + # Handle boolean flags. + for letter in list('46AaCfGgKkMNnqsTtVvXxYy'): + parser.add_argument('-' + letter, action='store_true') + # Handle options that require arguments. + for letter in list('BbcDEeFIiJLlmOopQRSWw'): + parser.add_argument('-' + letter) + # Look for the hostname, which should always be required. + parser.add_argument('hostname') + # Note that we ignore the command. + + # Try to parse the command line arguments after the program name. If we + # can't make sense of it, the subclass should throw our custom exception, + # which we'll catch and then return 'None' instead, meaning we can't tell + # what the hostname is, if there is one at all. + # + try: + args = parser.parse_args() + return args.hostname + except SSHArgumentParserError: + return None + + def get_compatible_term(self): + """ + If there's a better TERM string to use for this hostname, return it; otherwise, + return None, meaning that we shouldn't export a different TERM for the actual + ssh(1) command run. + """ + + # Require that there's a translation for this TERM string to use. + try: + term = self.config['translations'][self._os.environ['TERM']] + except KeyError: + return None + + # Require stdin, stdout, and stderr all point to a terminal. Yes, I know + # this might be more Pythonic as a list comprehension, but I tried that and + # found it harder to read. + # + for descriptor in [ + self._sys.stdin, + self._sys.stdout, + self._sys.stderr]: + if not descriptor.isatty(): + return None + + # Require that we have at least one argument to ssh(1); we'll assume the + # last argument is the hostname. This is just an heuristic, and will often + # not be true; it'll often be a command to be run on the remote system, but + # in such a case that's done for batch jobs rather than interactive + # processes, and it therefore often doesn't matter what the TERM is. + # + hostname = self.get_hostname() + if not hostname: + return None + + # If the hostname parses as an IPv4 or IPv6 address, downgrade the TERM for + # compatibility. + try: + ipaddress.ip_address(hostname) + return term + except ValueError: + pass + + # If the hostname matches the network equipment hostname pattern, downgrade + # the TERM for compatibility. + for pattern in self.config['patterns']: + if re.search(self.config['patterns'][pattern], hostname): + return term + + # Default to keeping the defined terminal (i.e., don't downgrade). + return None + + +if __name__ == '__main__': + snt = SSHNegotiateTerm() + snt.run() |