summaryrefslogtreecommitdiffstats
path: root/Monitoring/MonitoringService/Driver/SshTunnel.py
blob: f2b7d3922f03f42b800c92712bfd51c99562c920 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
'''
Created on Jan 14, 2013

@author: steger
'''

import select
import SocketServer
from Driver import Driver
from SshExec import SshDriver
from threading import Thread
from SshExec import SshExec

class SshTunnel(SshDriver):
    '''
    @summary: this class extends L{SshDriver} and establishes a connection
    to the requested SSH server and sets up local port
    forwarding (the openssh -L option) from a local port through a tunneled
    connection to a destination reachable from the SSH server machine.
    @ivar t: the thread running the local forwarding server
    @type t: threading.Thread or None
    @ivar service: the local listening server whose connections are forwarded
    @type service: L{SshTunnel.ForwardServer} or None
    '''

    class ForwardServer(SocketServer.ThreadingTCPServer):
        '''
        @summary: threaded TCP server listening on the local end of the tunnel;
        request handler threads are daemonic and the address may be reused
        '''
        daemon_threads = True
        allow_reuse_address = True

    class Handler(SocketServer.BaseRequestHandler):
        '''
        @summary: relays one accepted local connection through a 'direct-tcpip'
        channel of the ssh transport. Concrete subclasses must define the class
        attributes chain_host, chain_port and ssh_transport.
        @cvar buffer_size: the maximum number of bytes read by a single recv() call
        @type buffer_size: int
        '''
        buffer_size = 1024

        def handle(self):
            '''
            @summary: opens the forwarding channel and pumps data until either end closes
            '''
            try:
                # read once: the socket is closed before the final log line is written
                peername = self.request.getpeername()
                chan = self.ssh_transport.open_channel('direct-tcpip',
                                                       (self.chain_host, self.chain_port),
                                                       peername)
            except Exception as e:
                Driver.logger.debug('Incoming request to %s:%d failed: %s' % (self.chain_host,
                                                                              self.chain_port,
                                                                              repr(e)))
                return
            if chan is None:
                Driver.logger.debug('Incoming request to %s:%d was rejected by the SSH server.' %
                                    (self.chain_host, self.chain_port))
                return

            try:
                Driver.logger.debug('Tunnel open %r -> %r -> %r' % (peername,
                                                                    chan.getpeername(),
                                                                    (self.chain_host, self.chain_port)))
                self._relay(chan)
            finally:
                # release both ends even if the relay loop fails
                chan.close()
                self.request.close()
            Driver.logger.debug('Tunnel closed from %r' % (peername,))

        def _relay(self, chan):
            '''
            @summary: copies data in both directions between the local socket and
            the ssh channel until one side returns an empty read
            @param chan: the opened forwarding channel
            @type chan: paramiko.Channel
            '''
            while True:
                r, _, _ = select.select([self.request, chan], [], [])
                if self.request in r:
                    data = self.request.recv(self.buffer_size)
                    if len(data) == 0:
                        break
                    # send() may transmit only part of the buffer; sendall() retries until done
                    chan.sendall(data)
                if chan in r:
                    data = chan.recv(self.buffer_size)
                    if len(data) == 0:
                        break
                    self.request.sendall(data)

    def __init__(self):
        '''
        @summary: initializes the ssh driver; the tunnel itself is set up by L{connect}
        '''
        SshDriver.__init__(self)
        self.t = None
        self.service = None

    def connect(self, host, credential, localport, port, remoteserver, remoteport, known_host = None):
        '''
        @summary: set up the tunnel connection. The local port is bound before this
        method returns, so it is ready to accept connections afterwards; if binding
        fails the ssh connection is closed again and the error propagates.
        @param host: the host name of the remote server acting as a port forwarder
        @type host: str
        @param credential: the secret to use for connection set up
        @type credential: L{Credential}
        @param localport: the local port entry mapped to the remoteport
        @type localport: int
        @param port: the port of the forwarder ssh server
        @type port: int
        @param remoteserver: the sink of the tunnel
        @type remoteserver: str
        @param remoteport: the port of the tunnel sink
        @type remoteport: int
        @param known_host: a file name containing host signatures to check, if None AutoAddPolicy applies
        @type known_host: str
        '''
        SshDriver.connect(self, host, credential, port, known_host)
        try:
            # bind here rather than in the worker thread: callers may connect to
            # localport right after we return, and a bind error (e.g. port in use)
            # must reach the caller instead of silently ending the worker thread
            self.service = self._make_server(localport, remoteserver, remoteport)
        except Exception:
            # do not leave the ssh session open when the tunnel cannot be set up
            SshDriver.close(self)
            raise
        self.logger.info('Now forwarding port %d to %s:%d ...' % (localport, remoteserver, remoteport))
        self.t = Thread(target = self._tran)
        self.t.daemon = True
        self.t.start()

    def _make_server(self, localport, remoteserver, remoteport):
        '''
        @summary: creates the bound (not yet serving) local forwarding server
        @param localport: the local port to listen on
        @type localport: int
        @param remoteserver: the sink of the tunnel
        @type remoteserver: str
        @param remoteport: the port of the tunnel sink
        @type remoteport: int
        @return: the forwarding server
        @rtype: L{SshTunnel.ForwardServer}
        '''
        transport = self.client.get_transport()

        # SocketServer instantiates the handler class itself and gives it no access
        # to the outer server, so the tunnel parameters are passed as class
        # attributes of a per-tunnel subclass
        class SubHandler(self.Handler):
            chain_host = remoteserver
            chain_port = remoteport
            ssh_transport = transport

        return self.ForwardServer(('', localport), SubHandler)

    def _tran(self, localport = None, remoteserver = None, remoteport = None):
        '''
        @summary: thread worker serving the forwarding server until L{close} shuts it down
        @param localport: the local port entry mapped to the remoteport (only used if no server exists yet)
        @type localport: int or None
        @param remoteserver: the sink of the tunnel (only used if no server exists yet)
        @type remoteserver: str or None
        @param remoteport: the port of the tunnel sink (only used if no server exists yet)
        @type remoteport: int or None
        '''
        try:
            if self.service is None:
                self.service = self._make_server(localport, remoteserver, remoteport)
            self.service.serve_forever()
        except Exception as e:
            # a worker thread cannot report to its creator; at least leave a trace
            self.logger.error('Port forwarding aborted: %r' % (e,))

    def close(self):
        '''
        @summary: stops the forwarding server and its thread, then tears down the ssh connection
        '''
        if self.t is None:
            return
        if self.service is not None:
            # serve_forever() only returns after shutdown(), so stop it before joining;
            # joining first would always block for the full timeout
            self.service.shutdown()
            # release the listening socket
            self.service.server_close()
            self.service = None
        self.t.join(timeout = self.timeout)
        self.t = None
        self.logger.info('Port forwarding stopped @ %s.' % self.host)
        SshDriver.close(self)


class SshExecTunnel(SshTunnel):
    '''
    @summary: an extension of the L{SshTunnel} driver to execute commands
    on the remote machine accessed via the tunnel
    @ivar command: the string representation of the default command to run
    @type command: str
    @ivar localdriver: the representation of an ssh client connecting over an existing ssh tunnel
    @type localdriver: L{SshExec}
    '''

    def __init__(self, host, credential, localport, port, remoteserver, remoteport, remotecredential = None, command = "echo helloworld @ `hostname`", known_host = None):
        '''
        @summary: sets up the tunnel, then opens an ssh connection through its
        local end (localhost:localport) and stores a default command
        @param host: the host name of the remote server acting as a port forwarder
        @type host: str
        @param credential: the secret to use for tunnel set up
        @type credential: L{Credential}
        @param localport: the local port entry mapped to the remoteport
        @type localport: int
        @param port: the port of the forwarder ssh server
        @type port: int
        @param remoteserver: the sink of the tunnel
        @type remoteserver: str
        @param remoteport: the port of the tunnel sink
        @type remoteport: int
        @param remotecredential: the secret to use for connection set up, if None then we fall back to the credential
        @type remotecredential: L{Credential} or None
        @param command: the default remote command
        @type command: str
        @param known_host: a file name containing host signatures to check for the forwarder, if None AutoAddPolicy applies;
        the connection made over the tunnel is always opened with known_host = None
        @type known_host: str
        '''
        SshTunnel.__init__(self)
        self.connect(host, credential, localport, port, remoteserver, remoteport, known_host)
        self.command = command
        if remotecredential is None:
            remotecredential = credential
        # the tunnel sink is reached through the local end of the forwarded port
        self.localdriver = SshExec(host = 'localhost', credential = remotecredential, port = localport, command = command, known_host = None)
        self.logger.info("connected over tunnel")

    def execute(self, command = None):
        '''
        @summary: executes a remote command over the tunnel
        @param command: the command to run, if None, the default command is issued
        @type command: str or None
        @return: the standard output of the command run
        @rtype: paramiko.ChannelFile
        '''
        return self.localdriver.execute(command)

    def close(self):
        '''
        @summary: closes the ssh session running over the tunnel, then tears down the tunnel
        '''
        try:
            if self.t is not None:
                # the inner session rides on the forwarded port: end it before the
                # forwarder goes away. NOTE(review): assumes SshExec exposes close()
                # like SshDriver does; guarded in case it does not
                closer = getattr(getattr(self, 'localdriver', None), 'close', None)
                if closer is not None:
                    closer()
        finally:
            SshTunnel.close(self)