Skip to content

Commit

Permalink
Reduce server load from the WebSocket server
Browse files Browse the repository at this point in the history
Previously, the server was encoding and compressing each message for every client. This resulted in high server load, to the point that the WebSocket server couldn't keep up and dropped messages.

Unfortunately, the websockets package doesn't officially expose its innards. Therefore, this is a disgusting hack that inspects the enabled extensions, prepares the data, and then sends that directly, all using internal APIs of websockets version 7.0.

Compression (cf. RFC 7692) introduces a further complexity: context takeover. Normally, the compression context is reused across messages, but because clients connect at different times and might not be receiving all messages, the same compressor cannot be used across connections. Therefore, context takeover is disabled here.

This also adds some stats about the throughput.
  • Loading branch information
JustAnotherArchivist committed Dec 1, 2023
1 parent dd5ce62 commit 03b3376
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 21 deletions.
4 changes: 2 additions & 2 deletions INSTALL.backend
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ To run the backend, you will need:
- Bundler
- ExecJS supported runtime (for the dashboard)
(see https://github.com/sstephenson/execjs)
- Python 3.6+ and websockets (for the dashboard WebSocket)
- Python 3.6+ and websockets 7.0 (for the dashboard WebSocket)

(Little known fact: ArchiveBot is made to be as hard as possible to set
up. If you have trouble with these instructions, drop by in IRC for
Expand All @@ -26,7 +26,7 @@ Quick install, for Debian and Debian-esque systems like Ubuntu:
cd ArchiveBot
git submodule update --init
bundle install
pip install websockets # Or apt install python3-websockets, or whichever method you prefer.
pip install websockets==7.0 # Or apt install python3-websockets, or whichever method you prefer, but it must be version 7.0.


** STEP 2: INSTALL REDIS **
Expand Down
100 changes: 81 additions & 19 deletions dashboard/websocket.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import asyncio
import collections
import datetime
import io
import os
import sys
import websockets
import websockets.extensions.permessage_deflate
import websockets.framing


DEBUG = 'WSDEBUG' in os.environ and os.environ['WSDEBUG'] == '1'
Expand All @@ -22,63 +26,121 @@ async def stdin(loop):
return reader


async def stdin_to_amplifier(amplifier, loop, stats):
    """Read newline-delimited messages from stdin forever and feed them to the amplifier.

    Counts raw bytes read into stats['stdin read'] so print_status can report
    input throughput.
    """
    reader = await stdin(loop)
    while True:
        d = await reader.readline()
        stats['stdin read'] += len(d)
        amplifier.send(d.decode('utf-8').strip())


def websocket_extensions_to_key(extensions):
    """Reduce a connection's list of websockets extensions to a hashable dedupe key.

    PerMessageDeflate extensions with local_no_context_takeover are collapsed
    into a tuple of the parameters that affect the produced bytes, so that
    equivalently-configured connections can share one encoded frame; any other
    extension keys on the object itself.
    """
    key = []
    for ext in extensions:
        deflate_cls = websockets.extensions.permessage_deflate.PerMessageDeflate
        if isinstance(ext, deflate_cls) and ext.local_no_context_takeover:
            key.append((
                deflate_cls,
                ext.remote_max_window_bits,
                ext.local_max_window_bits,
                tuple(ext.compress_settings.items()),
            ))
        else:
            key.append(ext)
    return tuple(key)


class MessageAmplifier:
    """Fan a message out to every connected WebSocket client.

    Each registered websocket gets a bounded queue of ready-to-send frame
    bytes. Encoding and compression happen once per distinct extension
    configuration (not once per client), which is the whole point of this
    module's internal-API hack.
    """

    def __init__(self, stats):
        # stats: shared dict of byte counters ('frame writes', 'dropped', ...)
        self.queues = {} # websocket -> queue
        self._stats = stats

    def register(self, websocket):
        """Create, remember, and return the outgoing queue for a new client."""
        q = asyncio.Queue(maxsize = 1000)
        self.queues[websocket] = q
        return q

    def send(self, message):
        """Encode message once per extension configuration and enqueue it for every client."""
        #FIXME This abuses internal API of websockets==7.0
        # Using the normal `websocket.send` reencodes and recompresses the message for every client.
        # So we construct the relevant Frame once instead and push that to the individual queues.
        frame = websockets.framing.Frame(fin = True, opcode = websockets.framing.OP_TEXT, data = message.encode('utf-8'))
        data = {} # tuple of extensions key → bytes
        for websocket, queue in self.queues.items():
            extensionsKey = websocket_extensions_to_key(websocket.extensions)
            if extensionsKey not in data:
                output = io.BytesIO()
                frame.write(output.write, mask = False, extensions = websocket.extensions)
                data[extensionsKey] = output.getvalue()
                self._stats['frame writes'] += len(data[extensionsKey])
            try:
                queue.put_nowait(data[extensionsKey])
            except asyncio.QueueFull:
                # Pop one, try again; it should be impossible for this to fail, so no try/except here.
                dropped = queue.get_nowait()
                self._stats['dropped'] += len(dropped)
                queue.put_nowait(data[extensionsKey])

    def unregister(self, websocket):
        """Forget a disconnected client's queue."""
        del self.queues[websocket]


async def websocket_server(amplifier, websocket, path, stats):
    """Per-connection task: pump pre-encoded frames from this client's queue onto the wire.

    Frames are already encoded/compressed by MessageAmplifier.send, so they are
    written to the transport directly instead of going through websocket.send.
    Counts bytes written into stats['sent'].
    """
    queue = amplifier.register(websocket)
    try:
        while True:
            #FIXME See above; this is write_frame essentially
            data = await queue.get()
            await websocket.ensure_open()
            websocket.writer.write(data)
            stats['sent'] += len(data)
            if websocket.writer.transport is not None:
                if websocket.writer_is_closing():
                    # Yield control so a pending close can proceed before draining.
                    await asyncio.sleep(0)
            try:
                async with websocket._drain_lock:
                    await websocket.writer.drain()
            except ConnectionError:
                websocket.fail_connection()
                await websocket.ensure_open()
    except websockets.exceptions.ConnectionClosed: # Silence connection closures
        pass
    finally:
        amplifier.unregister(websocket)


async def print_status(amplifier, stats):
    """Print client, queue, CPU, memory, and throughput statistics once per interval."""
    interval = 60
    previousUtime = None
    previousStats = {} # snapshot of stats at the previous tick, for rate calculation
    while True:
        currentUtime = os.times().user
        # CPU usage over the last interval; nan on the first tick since there is no baseline yet.
        cpu = (currentUtime - previousUtime) / interval * 100 if previousUtime is not None else float('nan')
        print(f'{datetime.datetime.now():%Y-%m-%d %H:%M:%S} - ' +
            ', '.join([
                f'{len(amplifier.queues)} clients',
                f'{sum(q.qsize() for q in amplifier.queues.values())} total queue size',
                f'{cpu:.1f} % CPU',
                f'{get_rss()/1048576:.1f} MiB RSS',
                'throughput: ' + ', '.join(f'{(stats[k] - previousStats.get(k, 0))/interval/1000:.1f} kB/s {k}' for k in stats),
            ])
        )
        if DEBUG:
            for socket in amplifier.queues:
                print(f'  {socket.remote_address}: {amplifier.queues[socket].qsize()}')
        previousUtime = currentUtime
        previousStats.update(stats)
        await asyncio.sleep(interval)


def main():
    """Wire everything together: stats, amplifier, WebSocket server, stdin pump, status printer."""
    stats = {'stdin read': 0, 'frame writes': 0, 'sent': 0, 'dropped': 0}
    amplifier = MessageAmplifier(stats)
    # Disable context takeover (cf. RFC 7692) so the compression can be reused
    start_server = websockets.serve(
        lambda websocket, path: websocket_server(amplifier, websocket, path, stats),
        None,
        4568,
        extensions = [websockets.extensions.permessage_deflate.ServerPerMessageDeflateFactory(server_no_context_takeover = True)]
    )
    loop = asyncio.get_event_loop()
    loop.run_until_complete(start_server)
    loop.run_until_complete(asyncio.gather(stdin_to_amplifier(amplifier, loop, stats), print_status(amplifier, stats)))


if __name__ == '__main__':
Expand Down

0 comments on commit 03b3376

Please sign in to comment.