# Copyright: (c) OpenSpug Organization. https://github.com/openspug/spug # Copyright: (c) # Released under the AGPL-3.0 License. from paramiko.client import SSHClient, AutoAddPolicy from paramiko.rsakey import RSAKey from paramiko.ssh_exception import AuthenticationException from io import StringIO import base64 import time import re class SSH: def __init__(self, hostname, port=22, username='root', pkey=None, password=None, default_env=None, connect_timeout=10): self.stdout = None self.client = None self.channel = None self.sftp = None self.eof = 'Spug EOF 2108111926' self.default_env = self._make_env_command(default_env) self.regex = re.compile(r'Spug EOF 2108111926 (-?\d+)[\r\n]?') self.arguments = { 'hostname': hostname, 'port': port, 'username': username, 'password': password, 'pkey': RSAKey.from_private_key(StringIO(pkey)) if isinstance(pkey, str) else pkey, 'timeout': connect_timeout, 'banner_timeout': 30 } @staticmethod def generate_key(): key_obj = StringIO() key = RSAKey.generate(2048) key.write_private_key(key_obj) return key_obj.getvalue(), 'ssh-rsa ' + key.get_base64() def get_client(self): if self.client is not None: return self.client self.client = SSHClient() self.client.set_missing_host_key_policy(AutoAddPolicy) self.client.connect(**self.arguments) return self.client def ping(self): return True def add_public_key(self, public_key): command = f'mkdir -p -m 700 ~/.ssh && \ echo {public_key!r} >> ~/.ssh/authorized_keys && \ chmod 600 ~/.ssh/authorized_keys' exit_code, out = self.exec_command_raw(command) if exit_code != 0: raise Exception(f'add public key error: {out}') def exec_command_raw(self, command, environment=None): channel = self.client.get_transport().open_session() if environment: channel.update_environment(environment) channel.set_combine_stderr(True) channel.exec_command(command) code, output = channel.recv_exit_status(), channel.recv(-1) try: output = output.decode() except UnicodeDecodeError: output = output.decode(encoding='GBK') return code, output def exec_command(self, command, environment=None): channel = self._get_channel() command = self._handle_command(command, environment) channel.send(command) out, exit_code = '', -1 for line in self.stdout: match = self.regex.search(line) if match: exit_code = int(match.group(1)) line = line[:match.start()] out += line break out += line return exit_code, out def _win_exec_command_with_stream(self, command, environment=None): channel = self.client.get_transport().open_session() if environment: channel.update_environment(environment) channel.set_combine_stderr(True) channel.get_pty(width=102) channel.exec_command(command) stdout = channel.makefile("rb", -1) out = stdout.readline() while out: yield channel.exit_status, out.decode() out = stdout.readline() yield channel.recv_exit_status(), out.decode() def exec_command_with_stream(self, command, environment=None): channel = self._get_channel() command = self._handle_command(command, environment) channel.send(command) exit_code, line = -1, '' while True: line = channel.recv(8196).decode() if not line: break match = self.regex.search(line) if match: exit_code = int(match.group(1)) line = line[:match.start()] break yield exit_code, line yield exit_code, line def put_file(self, local_path, remote_path): sftp = self._get_sftp() sftp.put(local_path, remote_path) def put_file_by_fl(self, fl, remote_path, callback=None): sftp = self._get_sftp() sftp.putfo(fl, remote_path, callback=callback) def list_dir_attr(self, path): sftp = self._get_sftp() return sftp.listdir_attr(path) def remove_file(self, path): sftp = self._get_sftp() sftp.remove(path) def _get_channel(self): if self.channel: return self.channel counter = 0 self.channel = self.client.invoke_shell() command = 'export PS1= && stty -echo; unsetopt zle; set -e\n' if self.default_env: command += f'{self.default_env}\n' command += f'echo {self.eof} $?\n' self.channel.send(command.encode()) while True: if self.channel.recv_ready(): line = self.channel.recv(8196).decode() if self.regex.search(line): self.stdout = self.channel.makefile('r') break elif counter >= 100: self.client.close() raise Exception('Wait spug response timeout') else: counter += 1 time.sleep(0.1) return self.channel def _get_sftp(self): if self.sftp: return self.sftp self.sftp = self.client.open_sftp() return self.sftp def _break(self): time.sleep(5) command = f'\x03 echo {self.eof} -1\n' self.channel.send(command.encode()) def _make_env_command(self, environment): if not environment: return None str_envs = [] for k, v in environment.items(): k = k.replace('-', '_') if isinstance(v, str): v = v.replace("'", "'\"'\"'") str_envs.append(f"{k}='{v}'") str_envs = ' '.join(str_envs) return f'export {str_envs}' def _handle_command(self, command, environment): new_command = f'trap \'echo {self.eof} $?; rm -f $SPUG_EXEC_FILE\' EXIT\n' env_command = self._make_env_command(environment) if env_command: new_command += f'{env_command}\n' new_command += command b64_command = base64.standard_b64encode(new_command.encode()) commands = 'export SPUG_EXEC_FILE=$(mktemp)\n' commands += f'echo {b64_command.decode()} | base64 -di > $SPUG_EXEC_FILE\n' commands += 'bash $SPUG_EXEC_FILE\n' return commands def __enter__(self): self.get_client() transport = self.client.get_transport() if 'windows' in transport.remote_version.lower(): self.exec_command = self.exec_command_raw self.exec_command_with_stream = self._win_exec_command_with_stream return self def __exit__(self, exc_type, exc_val, exc_tb): self.client.close() self.client = None