Skip to content

ENH: Add Flag -w(--websocket) to support proxy browser mode #28

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: 4.2
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions boltkit/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,9 @@ async def a():
@click.option("-l", "--listen-addr", type=AddressParamType(), envvar="BOLT_LISTEN_ADDR")
@click.option("-s", "--server-addr", type=AddressListParamType(), envvar="BOLT_SERVER_ADDR")
@click.option("-v", "--verbose", count=True, callback=watch_log, expose_value=False, is_eager=True)
def proxy(server_addr, listen_addr):
proxy_server = ProxyServer(server_addr, listen_addr)
@click.option("-w", "--websocket", is_flag=True, default=False, type=click.BOOL, envvar="BOLT_IS_WEBSOCKET")
def proxy(server_addr, listen_addr, websocket):
proxy_server = ProxyServer(server_addr, listen_addr, websocket)
proxy_server.start()


Expand Down
1 change: 1 addition & 0 deletions boltkit/client/packstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
UINT_8 = ">B" # unsigned 8-bit integer
UINT_16 = ">H" # unsigned 16-bit integer
UINT_32 = ">I" # unsigned 32-bit integer
UINT_64 = ">Q" # unsigned 32-bit integer
FLOAT_64 = ">d" # IEEE double-precision floating-point format


Expand Down
151 changes: 123 additions & 28 deletions boltkit/server/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"""


from itertools import cycle
from logging import getLogger
from socket import socket, SOL_SOCKET, SO_REUSEADDR, AF_INET, AF_INET6
from struct import unpack_from as raw_unpack
Expand All @@ -31,90 +32,181 @@
from boltkit.addressing import Address, AddressList
from boltkit.server.bytetools import h
from boltkit.client import CLIENT, SERVER
from boltkit.client.packstream import UINT_32, Unpackable
from boltkit.client.packstream import UINT_64, UINT_32, UINT_16, Unpackable


log = getLogger("boltkit")


class Peer(object):

def __init__(self, socket, address):
def __init__(self, socket, address, session_idx):
self.socket = socket
self.bs_cache = b''
self.address = address
self.bolt_version = 0
self.session_idx = 0


class ProxyPair(Thread):

def __init__(self, client, server):
def __init__(self, client, server, is_websocket=False):
super(ProxyPair, self).__init__()
self.client = client
self.server = server
log.debug("C: <CONNECT> {} -> {}".format(self.client.address, self.server.address))
log.debug("C: <BOLT> {}".format(h(self.forward_bytes(client, server, 4))))
log.debug("C: <VERSION> {}".format(h(self.forward_bytes(client, server, 16))))
raw_bolt_version = self.forward_bytes(server, client, 4)
self.is_websocket = is_websocket

def get_version(self):
client, server = self.client, self.server
log.debug("C[{}]: <CONNECT> {} -> {}".format(self.client.session_idx, self.client.address, self.server.address))
log.debug("C[{}]: <BOLT> {}".format(self.client.session_idx, h(self.forward_bytes(self.client, self.server, 4, is_websocket=self.is_websocket))))
log.debug("C[{}]: <VERSION> {}".format(self.client.session_idx, h(self.forward_bytes(self.client, self.server, 16, is_websocket=self.is_websocket))))
raw_bolt_version = self.forward_bytes(server, client, 4, is_websocket=self.is_websocket)
bolt_version, = raw_unpack(UINT_32, raw_bolt_version)
bolt_version = (bolt_version % 0x100, bolt_version // 0x100)
self.client.bolt_version = self.server.bolt_version = bolt_version
log.debug("S: <VERSION> {}".format(h(raw_bolt_version)))
log.debug("S[{}]: <VERSION> {}".format(self.server.session_idx, h(bolt_version)))
self.client_messages = {v: k for k, v in CLIENT[self.client.bolt_version].items()}
self.server_messages = {v: k for k, v in SERVER[self.server.bolt_version].items()}

def read_header(self, client, server, TAG):
while True:
line = self.forward_line(client, server)
log.debug("{} Header: {}".format(TAG, line))
if line == b'\r\n':
break

def run(self):
if self.is_websocket:
self.read_header(self.client, self.server, "Request[{}]".format(self.client.session_idx))
self.read_header(self.server, self.client, "Response[{}]".format(self.server.session_idx))
self.get_version()
client = self.client
server = self.server
more = True
while more:
try:
self.forward_exchange(client, server)
self.forward_exchange(client, server, is_websocket=self.is_websocket)
except RuntimeError:
more = False
log.debug("C: <CLOSE>")
log.debug("C[{}]: <CLOSE>".format(self.client.session_idx))

@classmethod
def unmask(cls, mask, bs):
bytes_unmask = []
for m, byte in zip(cycle(mask), bs):
bytes_unmask.append(m ^ byte)
return bytes(bytes_unmask)

@classmethod
def forward_bytes(cls, source, target, size):
data = source.socket.recv(size)
target.socket.sendall(data)
def source_recv_then_forward_target(cls, source, target, size, is_websocket=False):
if is_websocket:
while len(source.bs_cache) < size:
source.bs_cache += cls.decompress_bytes_from_websocket(source, target)
data = source.bs_cache[:size]
source.bs_cache = source.bs_cache[size:]
return data
else:
data = source.socket.recv(size)
target.socket.sendall(data)
return data

@classmethod
def forward_chunk(cls, source, target):
chunk_header = cls.forward_bytes(source, target, 2)
def decompress_bytes_from_websocket(cls, source, target):
'''
Websocket Protocol Defines:
Ref: https://tools.ietf.org/html/rfc6455
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-------+-+-------------+-------------------------------+
|F|R|R|R| opcode|M| Payload len | Extended payload length |
|I|S|S|S| (4) |A| (7) | (16/64) |
|N|V|V|V| |S| | (if payload len==126/127) |
| |1|2|3| |K| | |
+-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
| Extended payload length continued, if payload len == 127 |
+ - - - - - - - - - - - - - - - +-------------------------------+
| |Masking-key, if MASK set to 1 |
+-------------------------------+-------------------------------+
| Masking-key (continued) | Payload Data |
+-------------------------------- - - - - - - - - - - - - - - - +
: Payload Data continued ... :
+ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
| Payload Data continued ... |
+---------------------------------------------------------------+
'''
opthead = cls.source_recv_then_forward_target(source, target, 2)
if len(opthead) < 2:
raise RuntimeError("RECV Empty: May caused by socket loss!")
# FIN = opthead[0] >> 7
# RSV1 = opthead[0] >> 6 & 1
# RSV2 = opthead[0] >> 5 & 1
# RSV3 = opthead[0] >> 4 & 1
# OPCODE = opthead[0] & 0b1111
MASK = opthead[1] >> 7
PAYLOAD_LEN = opthead[1] & 0b01111111
if PAYLOAD_LEN == 126:
PAYLOAD_LEN = raw_unpack(UINT_16, cls.source_recv_then_forward_target(source, target, 2))[0]
elif PAYLOAD_LEN == 127:
PAYLOAD_LEN = raw_unpack(UINT_64, cls.source_recv_then_forward_target(source, target, 8))[0]
if MASK == 0:
data = cls.source_recv_then_forward_target(source, target, PAYLOAD_LEN)
return data
else:
mask = cls.source_recv_then_forward_target(source, target, 4)
mask_data = cls.source_recv_then_forward_target(source, target, PAYLOAD_LEN)
unmask_data = cls.unmask(mask, mask_data)
return unmask_data

@classmethod
def forward_bytes(cls, source, target, size, is_websocket=False):
return cls.source_recv_then_forward_target(source, target, size, is_websocket=is_websocket)


@classmethod
def forward_line(cls, source, target):
data = b''
while True:
data += cls.source_recv_then_forward_target(source, target, 1) # source.socket.recv(1)
if len(data) >= 2 and data[-2:] == b'\r\n':
return data

@classmethod
def forward_chunk(cls, source, target, is_websocket=False):
chunk_header = cls.forward_bytes(source, target, 2, is_websocket)
if not chunk_header:
raise RuntimeError()
chunk_size = chunk_header[0] * 0x100 + chunk_header[1]
return cls.forward_bytes(source, target, chunk_size)
return cls.forward_bytes(source, target, chunk_size, is_websocket)

@classmethod
def forward_message(cls, source, target):
def forward_message(cls, source, target, is_websocket=False):
d = b""
size = -1
while size:
data = cls.forward_chunk(source, target)
data = cls.forward_chunk(source, target, is_websocket)
size = len(data)
d += data
return d

def forward_exchange(self, client, server):
rq_message = self.forward_message(client, server)
def forward_exchange(self, client, server, is_websocket=False):
rq_message = self.forward_message(client, server, is_websocket)
rq_signature = rq_message[1]
rq_data = Unpackable(rq_message[2:]).unpack_all()
log.debug("C: {} {}".format(self.client_messages[rq_signature], " ".join(map(repr, rq_data))))
log.debug("C[{}]: {} {}".format(self.client.session_idx, self.client_messages[rq_signature], " ".join(map(repr, rq_data))))
more = True
while more:
rs_message = self.forward_message(server, client)
rs_message = self.forward_message(server, client, is_websocket)
rs_signature = rs_message[1]
rs_data = Unpackable(rs_message[2:]).unpack_all()
log.debug("S: {} {}".format(self.server_messages[rs_signature], " ".join(map(repr, rs_data))))
log.debug("S[{}]: {} {}".format(self.server.session_idx, self.server_messages[rs_signature], " ".join(map(repr, rs_data))))
more = rs_signature == 0x71


class ProxyServer(Thread):

running = False

def __init__(self, server_addr, listen_addr=None):
def __init__(self, server_addr, listen_addr=None, is_websocket=False):
super(ProxyServer, self).__init__()
self.socket = socket()
self.socket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
Expand All @@ -125,19 +217,22 @@ def __init__(self, server_addr, listen_addr=None):
server_addr.resolve()
self.server_addr = server_addr[0]
self.pairs = []
self.is_websocket = is_websocket

def __del__(self):
self.stop()

def run(self):
self.running = True
session_idx = 0
while self.running:
session_idx += 1
client_socket, client_address = self.socket.accept()
server_socket = socket({2: AF_INET, 4: AF_INET6}[len(self.server_addr)])
server_socket.connect(self.server_addr)
client = Peer(client_socket, client_address)
server = Peer(server_socket, self.server_addr)
pair = ProxyPair(client, server)
client = Peer(client_socket, client_address, session_idx)
server = Peer(server_socket, self.server_addr, session_idx)
pair = ProxyPair(client, server, self.is_websocket)
pair.start()
self.pairs.append(pair)

Expand Down