123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206 |
- # Copyright: (c) OpenSpug Organization. https://github.com/openspug/spug
- # Copyright: (c) <spug.dev@gmail.com>
- # Released under the AGPL-3.0 License.
- from paramiko.client import SSHClient, AutoAddPolicy
- from paramiko.rsakey import RSAKey
- from paramiko.ssh_exception import AuthenticationException
- from io import StringIO
- import base64
- import time
- import re
class SSH:
    """Convenience wrapper around a paramiko ``SSHClient``.

    Commands can be run in two ways:

    * through a persistent interactive shell (``exec_command`` /
      ``exec_command_with_stream``), where the end of each command is
      detected by a marker line ``<eof> <exit code>`` echoed by the remote
      shell;
    * through a one-off exec session (``exec_command_raw`` /
      ``_win_exec_command_with_stream``).

    The class is a context manager: entering connects, exiting closes the
    client.  When the remote SSH version string contains ``windows`` the
    shell-based methods are swapped for the exec-session variants.
    """

    def __init__(self, hostname, port=22, username='root', pkey=None, password=None, default_env=None,
                 connect_timeout=10):
        """Store connection settings; no network activity happens here.

        :param hostname: host to connect to.
        :param port: SSH port.
        :param username: login user.
        :param pkey: private key, either as PEM text (parsed as an RSA key)
            or as an already-built key object passed through unchanged.
        :param password: login password.
        :param default_env: mapping of environment variables exported once
            when the interactive shell is first opened.
        :param connect_timeout: value passed to paramiko as ``timeout``.
        """
        self.stdout = None
        self.client = None
        self.channel = None
        self.sftp = None
        # Marker echoed by the remote shell, followed by the exit code.
        self.eof = 'Spug EOF 2108111926'
        self.default_env = self._make_env_command(default_env)
        self.regex = re.compile(r'Spug EOF 2108111926 (-?\d+)[\r\n]?')
        self.arguments = {
            'hostname': hostname,
            'port': port,
            'username': username,
            'password': password,
            'pkey': RSAKey.from_private_key(StringIO(pkey)) if isinstance(pkey, str) else pkey,
            'timeout': connect_timeout,
            'banner_timeout': 30
        }

    @staticmethod
    def generate_key():
        """Generate a 2048-bit RSA key pair.

        :return: ``(private_key_text, public_key_line)`` where the public
            part is formatted as ``'ssh-rsa <base64>'``.
        """
        key_obj = StringIO()
        key = RSAKey.generate(2048)
        key.write_private_key(key_obj)
        return key_obj.getvalue(), 'ssh-rsa ' + key.get_base64()

    @staticmethod
    def _decode_output(data):
        """Decode command output as UTF-8, falling back to GBK."""
        try:
            return data.decode()
        except UnicodeDecodeError:
            return data.decode(encoding='GBK')

    def get_client(self):
        """Return the connected client, connecting on first use.

        The client is cached on ``self.client`` only after ``connect``
        succeeds, so a failed attempt does not leave an unconnected client
        behind for later calls to pick up.
        """
        if self.client is not None:
            return self.client
        client = SSHClient()
        # NOTE(review): AutoAddPolicy accepts unknown host keys without
        # verification; confirm this is acceptable for the deployment.
        client.set_missing_host_key_policy(AutoAddPolicy)
        try:
            client.connect(**self.arguments)
        except Exception:
            client.close()
            raise
        self.client = client
        return self.client

    def ping(self):
        """Always return ``True``; no remote check is performed."""
        return True

    def add_public_key(self, public_key):
        """Append ``public_key`` to the remote ``~/.ssh/authorized_keys``.

        :raises Exception: when the remote command exits non-zero.
        """
        import shlex

        # shlex.quote produces real POSIX shell quoting; Python's repr() does
        # not (it switches to double quotes when the text contains a single
        # quote, which lets the shell expand ``$`` and backticks).
        command = (
            'mkdir -p -m 700 ~/.ssh && '
            f'echo {shlex.quote(public_key)} >> ~/.ssh/authorized_keys && '
            'chmod 600 ~/.ssh/authorized_keys'
        )
        exit_code, out = self.exec_command_raw(command)
        if exit_code != 0:
            raise Exception(f'add public key error: {out}')

    def exec_command_raw(self, command, environment=None):
        """Run ``command`` in a fresh exec session (stderr merged into stdout).

        :return: ``(exit_code, output_text)``.
        """
        channel = self.client.get_transport().open_session()
        if environment:
            channel.update_environment(environment)
        channel.set_combine_stderr(True)
        channel.exec_command(command)
        # NOTE(review): paramiko documents that recv_exit_status() may block
        # when called before output larger than the channel window has been
        # read; confirm large outputs are not expected on this path.
        code, output = channel.recv_exit_status(), channel.recv(-1)
        return code, self._decode_output(output)

    def exec_command(self, command, environment=None):
        """Run ``command`` in the persistent shell and collect its output.

        :return: ``(exit_code, output_text)``; the exit code stays ``-1`` if
            the output stream ends before the marker line is seen.
        """
        channel = self._get_channel()
        script = self._handle_command(command, environment)
        # sendall() keeps writing until everything is delivered; send() may
        # transmit only part of a large script.
        channel.sendall(script.encode())
        exit_code, pieces = -1, []
        for line in self.stdout:
            match = self.regex.search(line)
            if match:
                exit_code = int(match.group(1))
                pieces.append(line[:match.start()])
                break
            pieces.append(line)
        return exit_code, ''.join(pieces)

    def _win_exec_command_with_stream(self, command, environment=None):
        """Stream output of ``command`` line by line from an exec session.

        Yields ``(exit_status, text)`` for each line read, then a final pair
        carrying the value of ``recv_exit_status()`` and an empty string.
        """
        channel = self.client.get_transport().open_session()
        if environment:
            channel.update_environment(environment)
        channel.set_combine_stderr(True)
        channel.get_pty(width=102)
        channel.exec_command(command)
        stdout = channel.makefile("rb", -1)
        out = stdout.readline()
        while out:
            yield channel.exit_status, self._decode_output(out)
            out = stdout.readline()
        yield channel.recv_exit_status(), self._decode_output(out)

    def _partial_marker_len(self, text):
        """Return how many trailing characters of ``text`` could be the start
        of the end-of-command marker (0 when none)."""
        marker = f'{self.eof} '
        for size in range(min(len(text), len(marker)), 0, -1):
            if marker.startswith(text[-size:]):
                return size
        return 0

    def exec_command_with_stream(self, command, environment=None):
        """Run ``command`` in the persistent shell, yielding output as it arrives.

        Yields ``(exit_code, text)`` pairs.  ``exit_code`` is ``-1`` until the
        marker line is seen; the last pair carries the real exit code (or
        ``-1`` if the channel closed first) together with any remaining text.
        """
        import codecs

        channel = self._get_channel()
        script = self._handle_command(command, environment)
        channel.sendall(script.encode())
        # Incremental decoding: a multi-byte character split across two
        # network chunks must not raise or be corrupted.
        decoder = codecs.getincrementaldecoder('utf-8')(errors='replace')
        exit_code, pending = -1, ''
        while True:
            data = channel.recv(8196)
            if not data:
                pending += decoder.decode(b'', final=True)
                break
            pending += decoder.decode(data)
            match = self.regex.search(pending)
            if match:
                exit_code = int(match.group(1))
                pending = pending[:match.start()]
                break
            # Hold back a tail that may be the first half of the marker so it
            # can be matched once the rest arrives.
            keep = self._partial_marker_len(pending)
            split_at = len(pending) - keep
            ready, pending = pending[:split_at], pending[split_at:]
            if ready:
                yield exit_code, ready
        yield exit_code, pending

    def put_file(self, local_path, remote_path):
        """Upload the local file at ``local_path`` to ``remote_path`` via SFTP."""
        sftp = self._get_sftp()
        sftp.put(local_path, remote_path)

    def put_file_by_fl(self, fl, remote_path, callback=None):
        """Upload the file-like object ``fl`` to ``remote_path`` via SFTP.

        ``callback`` is forwarded to paramiko's ``putfo``.
        """
        sftp = self._get_sftp()
        sftp.putfo(fl, remote_path, callback=callback)

    def list_dir_attr(self, path):
        """Return the SFTP attribute entries for the remote directory ``path``."""
        sftp = self._get_sftp()
        return sftp.listdir_attr(path)

    def remove_file(self, path):
        """Delete the remote file ``path`` via SFTP."""
        sftp = self._get_sftp()
        sftp.remove(path)

    def _get_channel(self):
        """Return the persistent shell channel, opening it on first use.

        On first use the shell is configured (empty prompt, no echo,
        ``set -e``, optional default environment) and the method waits for
        the marker line that confirms the setup command ran.

        :raises Exception: when no marker is received after 100 idle polls of
            0.1 s each; the client is closed in that case.
        """
        if self.channel:
            return self.channel

        shell = self.client.invoke_shell()
        command = 'export PS1= && stty -echo; unsetopt zle; set -e\n'
        if self.default_env:
            command += f'{self.default_env}\n'
        command += f'echo {self.eof} $?\n'
        shell.sendall(command.encode())

        received, idle_polls = '', 0
        while True:
            if shell.recv_ready():
                # Accumulate so a marker split across reads is still found;
                # decode leniently because banner text is only searched for
                # the ASCII marker and never returned.
                received += shell.recv(8196).decode(errors='replace')
                if self.regex.search(received):
                    self.stdout = shell.makefile('r')
                    break
            elif idle_polls >= 100:
                self.client.close()
                raise Exception('Wait spug response timeout')
            else:
                idle_polls += 1
                time.sleep(0.1)
        # Publish the channel only once it is usable, so a timed-out attempt
        # does not leave a dead channel cached.
        self.channel = shell
        return self.channel

    def _get_sftp(self):
        """Return the cached SFTP session, opening it on first use."""
        if self.sftp:
            return self.sftp
        self.sftp = self.client.open_sftp()
        return self.sftp

    def _break(self):
        """After a 5 s delay, send Ctrl-C followed by a marker with exit code -1."""
        time.sleep(5)
        command = f'\x03 echo {self.eof} -1\n'
        self.channel.sendall(command.encode())

    def _make_env_command(self, environment):
        """Build an ``export K='V' ...`` shell line from a mapping.

        Dashes in names are replaced by underscores and single quotes inside
        string values are escaped for single-quoted shell strings.

        :return: the command text, or ``None`` when ``environment`` is empty.
        """
        if not environment:
            return None
        assignments = []
        for key, value in environment.items():
            key = key.replace('-', '_')
            if isinstance(value, str):
                value = value.replace("'", "'\"'\"'")
            assignments.append(f"{key}='{value}'")
        return f"export {' '.join(assignments)}"

    def _handle_command(self, command, environment):
        """Wrap ``command`` into the text sent to the persistent shell.

        The script (an EXIT trap that echoes the marker with ``$?`` and
        removes the temp file, optional exports, then ``command``) is
        base64-encoded, decoded remotely into a ``mktemp`` file and run with
        ``bash``.
        """
        new_command = f'trap \'echo {self.eof} $?; rm -f $SPUG_EXEC_FILE\' EXIT\n'
        env_command = self._make_env_command(environment)
        if env_command:
            new_command += f'{env_command}\n'
        new_command += command
        b64_command = base64.standard_b64encode(new_command.encode())
        commands = 'export SPUG_EXEC_FILE=$(mktemp)\n'
        commands += f'echo {b64_command.decode()} | base64 -di > $SPUG_EXEC_FILE\n'
        commands += 'bash $SPUG_EXEC_FILE\n'
        return commands

    def __enter__(self):
        """Connect and, for Windows servers, switch to exec-session methods."""
        self.get_client()
        transport = self.client.get_transport()
        if 'windows' in transport.remote_version.lower():
            self.exec_command = self.exec_command_raw
            self.exec_command_with_stream = self._win_exec_command_with_stream
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the client and drop every handle that depended on it."""
        if self.client is not None:
            self.client.close()
        self.client = None
        # These belong to the closed connection; clearing them lets the
        # instance be entered again without reusing dead handles.
        self.channel = None
        self.sftp = None
        self.stdout = None
|