aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--config.sample5
-rwxr-xr-xssh_negotiate_term151
2 files changed, 156 insertions, 0 deletions
diff --git a/config.sample b/config.sample
new file mode 100644
index 0000000..9566967
--- /dev/null
+++ b/config.sample
@@ -0,0 +1,5 @@
+[patterns]
+# This will match the exact hostname "examplehost"
+example_host = ^examplehost$
+# This will match any host matching a pattern like abc-example-def-123a
+network_equipment = ^[a-z]{3}-[a-z]{2,}-[a-z]{2,}-[0-9]
diff --git a/ssh_negotiate_term b/ssh_negotiate_term
new file mode 100755
index 0000000..cedbd6e
--- /dev/null
+++ b/ssh_negotiate_term
@@ -0,0 +1,151 @@
+#!/usr/bin/python3
+"""
+Wrap ssh(1) commands, parse the options, and check whether the hostname
+argument is either an IPv4/IPv6 address or matches the hostname pattern for
+network equipment rather than servers. If so, and we're using a TERM type
+that's not well-supported on network gear, send a different TERM setting with
+a more broadly-compatible TERM type.
+"""
+import argparse
+import configparser
+import ipaddress
+import os
+import re
+import sys
+
+
+class SSHArgumentParserError(Exception):
+ """
+ Exception for SSHArgumentParser to throw on error.
+ """
+
+
+class SSHArgumentParser(argparse.ArgumentParser):
+ """
+ Subclass argparse.ArgumentParser just so that errors raise exceptions
+ rather than exiting.
+ """
+
+ def error(self, message):
+ raise SSHArgumentParserError(message)
+
+
+class SSHNegotiateTerm():
+ """
+ Using a class, to encapsulate a fair bit of system state that gets injected
+ into this.
+ """
+
+ def __init__(self, _os=os, _sys=sys, config_path=None):
+ self._os = _os
+ self._sys = _sys
+
+ config = configparser.ConfigParser()
+ config['patterns'] = {}
+ config['ssh'] = {}
+ config['ssh']['path'] = '/usr/bin/ssh'
+ config['translations'] = {}
+ if not config_path:
+ config_path = self._os.path.expanduser(
+ '~/.config/ssh_negotiate_term/config')
+ config.read(self._os.path.expanduser(config_path))
+ self.config = config
+
+ def run(self):
+ """
+ Build argument vector for the real SSH command from what we were passed; if
+ there's a compatible TERM string we should use for this call, add it with an
+ ssh(1) option -o string.
+ """
+ args = [self.config['ssh']['path']]
+ term = self.get_compatible_term()
+ if term:
+ args.extend(['-o', f'SetEnv=TERM={term}'])
+ args.extend(self._sys.argv[1:])
+ self._os.execv(self.config['ssh']['path'], args)
+
+ def get_hostname(self):
+ """
+ Given an argument vector for an ssh(1) command---including the `ssh`
+ command itself---iterate through it until we find what should be the
+ hostname, taking account of argument options. Return 'None' on fail.
+ """
+
+ # Set up a parser object, a subclass of the stock argparse one.
+ parser = SSHArgumentParser()
+ # Handle boolean flags.
+ for letter in list('46AaCfGgKkMNnqsTtVvXxYy'):
+ parser.add_argument('-' + letter, action='store_true')
+ # Handle options that require arguments.
+ for letter in list('BbcDEeFIiJLlmOopQRSWw'):
+ parser.add_argument('-' + letter)
+ # Look for the hostname, which should always be required.
+ parser.add_argument('hostname')
+ # Note that we ignore the command.
+
+ # Try to parse the command line arguments after the program name. If we
+ # can't make sense of it, the subclass should throw our custom exception,
+ # which we'll catch and then return 'None' instead, meaning we can't tell
+ # what the hostname is, if there is one at all.
+ #
+ try:
+ args = parser.parse_args()
+ return args.hostname
+ except SSHArgumentParserError:
+ return None
+
+ def get_compatible_term(self):
+ """
+ If there's a better TERM string to use for this hostname, return it; otherwise,
+ return None, meaning that we shouldn't export a different TERM for the actual
+ ssh(1) command run.
+ """
+
+ # Require that there's a translation for this TERM string to use.
+ try:
+ term = self.config['translations'][self._os.environ['TERM']]
+ except KeyError:
+ return None
+
+ # Require stdin, stdout, and stderr all point to a terminal. Yes, I know
+ # this might be more Pythonic as a list comprehension, but I tried that and
+ # found it harder to read.
+ #
+ for descriptor in [
+ self._sys.stdin,
+ self._sys.stdout,
+ self._sys.stderr]:
+ if not descriptor.isatty():
+ return None
+
+ # Require that we have at least one argument to ssh(1); we'll assume the
+ # last argument is the hostname. This is just an heuristic, and will often
+ # not be true; it'll often be a command to be run on the remote system, but
+ # in such a case that's done for batch jobs rather than interactive
+ # processes, and it therefore often doesn't matter what the TERM is.
+ #
+ hostname = self.get_hostname()
+ if not hostname:
+ return None
+
+ # If the hostname parses as an IPv4 or IPv6 address, downgrade the TERM for
+ # compatibility.
+ try:
+ ipaddress.ip_address(hostname)
+ return term
+ except ValueError:
+ pass
+
+ # If the hostname matches the network equipment hostname pattern, downgrade
+ # the TERM for compatibility.
+ for pattern in self.config['patterns']:
+ if re.search(self.config['patterns'][pattern], hostname):
+ return term
+
+ # Default to keeping the defined terminal (i.e., don't downgrade).
+ return None
+
+
+if __name__ == '__main__':
+ snt = SSHNegotiateTerm()
+ snt.run()