Files
azul/front/msn/msnp.py
T
Athena Funderburg 21f38ee3e1 production init
2026-05-26 16:41:23 +00:00

189 lines
4.6 KiB
Python

import io, asyncio, settings
from abc import ABCMeta, abstractmethod
from typing import List, Tuple, Any, Optional, Callable, Iterable, Sequence
from urllib.parse import unquote, quote
from util.misc import Logger
from .misc import MSNObj
class MSNPCtrl(metaclass=ABCMeta):
    """Abstract base for an MSNP connection controller.

    Owns one `MSNPReader` and one `MSNPWriter`. Incoming bytes are parsed
    into messages and each message is dispatched to the method named
    `_m_<command>` (command lower-cased); outgoing replies are serialised
    by the writer and pushed to `self.transport` when one is set.
    """

    __slots__ = ('logger', 'reader', 'writer', 'peername', 'closed', 'close_callback', 'transport')

    logger: Logger
    reader: 'MSNPReader'
    writer: 'MSNPWriter'
    peername: Tuple[str, int]
    close_callback: Optional[Callable[[], None]]
    closed: bool
    transport: Optional[asyncio.WriteTransport]

    def __init__(self, logger: Logger) -> None:
        """Create the reader/writer pair and reset connection state."""
        self.logger = logger
        self.reader = MSNPReader(logger)
        self.writer = MSNPWriter(logger)
        # Placeholder until `data_received` reads the real peer from a transport.
        self.peername = ('0.0.0.0', 1863)
        self.close_callback = None
        self.closed = False
        self.transport = None

    @abstractmethod
    def on_connect(self) -> None:
        """Abstract; subclasses must implement. Not invoked by this class."""

    def data_received(self, data: bytes, *, transport: Optional[asyncio.BaseTransport] = None) -> None:
        """Parse `data` and dispatch every complete message it yields.

        `transport` overrides `self.transport` for this call only; one of
        the two must be set. Any exception raised while looking up or
        running a handler is logged and the next message is processed.
        """
        active = self.transport if transport is None else transport
        assert active is not None
        self.peername = active.get_extra_info('peername')
        for message in self.reader.data_received(data):
            try:
                handler = getattr(self, f'_m_{message[0].lower()}')
                handler(*message[1:])
            except Exception as ex:
                self.logger.error(ex)

    def send_reply(self, *m: Any) -> None:
        """Queue message `m` on the writer and flush it to the transport, if any."""
        self.writer.write(m)
        out = self.transport
        if out is None:
            # No transport attached: the bytes stay buffered in the writer.
            return
        out.write(self.flush())

    def flush(self) -> bytes:
        """Return (and clear) everything the writer has buffered."""
        return self.writer.flush()

    def _m_out(self) -> None:
        """Handler for the `OUT` command: close the connection."""
        self.close()

    def close(self) -> None:
        """Mark the controller closed and run the close hooks exactly once."""
        if self.closed:
            return
        self.closed = True
        callback = self.close_callback
        if callback:
            callback()
        self._on_close()

    @abstractmethod
    def _on_close(self) -> None:
        """Abstract; called by `close()` after the optional close callback."""
class MSNPWriter:
    """Serialises outgoing MSNP messages into an in-memory byte buffer."""

    __slots__ = ('_logger', '_buf')

    _logger: Logger
    _buf: io.BytesIO

    def __init__(self, logger: Logger) -> None:
        """Start with an empty buffer; `logger` receives a line per message."""
        self._logger = logger
        self._buf = io.BytesIO()

    def write(self, m: Iterable[Any]) -> None:
        """Append one message to the buffer.

        The elements of `m` become space-separated tokens on a line ending
        in CRLF. `None` elements are dropped; '%', ' ', CR and LF inside a
        token are percent-escaped. The final element is special-cased:

        * `bytes` - sent verbatim after the line, with its length taking
          its place as the last token;
        * `MSNObj` - its `.data`, if truthy, is appended as a token encoded
          by `_encode_msnobj` instead of the escaping above.
        """
        parts = list(m)
        payload = None
        msnobj_data = None
        if isinstance(parts[-1], bytes):
            payload = parts[-1]
            parts[-1] = len(payload)
        elif isinstance(parts[-1], MSNObj):
            msnobj_data = parts[-1].data
            parts[-1] = None

        # '%' must be escaped first so the escapes added afterwards survive.
        escapes = (('%', '%25'), (' ', '%20'), ('\r', '%0D'), ('\n', '%0A'))
        tokens = []
        for part in parts:
            if part is None:
                continue
            text = str(part)
            for raw, escaped in escapes:
                text = text.replace(raw, escaped)
            tokens.append(text)
        mt = tuple(tokens)

        if msnobj_data:
            encoded = _encode_msnobj(msnobj_data)
            assert encoded is not None
            mt = mt + (encoded,)

        _log(self._logger, '[Server]', mt)

        self._buf.write(' '.join(mt).encode('utf-8'))
        self._buf.write(b'\r\n')
        if payload is not None:
            self._buf.write(payload)
            if settings.DEBUG_FULL:
                print(payload)

    def flush(self) -> bytes:
        """Return all buffered bytes and reset the buffer if it held any."""
        pending = self._buf.getvalue()
        if not pending:
            return pending
        self._buf = io.BytesIO()
        return pending
class MSNPReader:
    """Accumulates received bytes and splits them into MSNP messages."""

    __slots__ = ('logger', '_data', '_i')

    logger: Logger
    _data: bytes
    _i: int

    def __init__(self, logger: Logger) -> None:
        """Start with an empty receive buffer and a zero read offset."""
        self.logger = logger
        self._data = b''
        self._i = 0

    def data_received(self, data: bytes) -> Iterable[List[Any]]:
        """Add `data` to the buffer and yield each complete message in it.

        Stops as soon as the remaining bytes do not form a full message;
        those bytes stay buffered for the next call.
        """
        self._data = (self._data + data) if self._data else data
        while self._data:
            message = self._read_msnp()
            if message is None:
                return
            yield message

    def _read_msnp(self) -> Optional[List[Any]]:
        """Decode one message from the front of the buffer.

        Returns the URL-unquoted tokens (with the payload appended last
        when it is non-empty), or None when `_msnp_try_decode` signals via
        AssertionError that the buffer is incomplete. Any other decoding
        error is printed and re-raised.
        """
        try:
            tokens, body, end = _msnp_try_decode(self._data, self._i)
        except AssertionError:
            return None
        except Exception:
            print("ERR _read_msnp", self._i, self._data)
            raise

        # Drop the consumed bytes and restart at the head of what is left.
        self._data = self._data[end:]
        self._i = 0

        _log(self.logger, '[Client]', tokens)

        message = [unquote(token) for token in tokens]
        if body:
            message.append(body)
            if settings.DEBUG_FULL:
                print(body)
        return message

    def _read_raw(self, n: int) -> bytes:
        """Return the next `n` buffered bytes and advance the read offset."""
        start = self._i
        stop = start + n
        assert stop <= len(self._data)
        self._i = stop
        return self._data[start:stop]
def _msnp_try_decode(d: bytes, i: int) -> Tuple[List[Any], Optional[bytes], int]:
    """Try to parse one MSNP message from buffer `d` starting at index `i`.

    Returns a tuple `(m, body, e)`:
        m    - the whitespace-separated tokens of the command line (still
               percent-encoded); for payload commands the trailing length
               token has been removed.
        body - the payload bytes for commands in `_PAYLOAD_COMMANDS`,
               otherwise None.
        e    - index in `d` just past the message (line plus payload).

    Raises:
        AssertionError: `d` does not hold a complete message yet (no
            newline after `i`, a command line of at most one character,
            or a payload that has not fully arrived). `MSNPReader._read_msnp`
            catches this type and waits for more data.
        ValueError: the payload length is missing, not an integer, or
            negative (also raised for undecodable UTF-8, as
            UnicodeDecodeError).
    """
    # The incompleteness checks raise explicitly instead of using `assert`:
    # assert statements are removed under `python -O`, which would silently
    # disable the "need more data" signal the reader depends on.
    e = d.find(b'\n', i)
    if e < 0:
        raise AssertionError('incomplete MSNP command line')
    e += 1
    m_str = d[i:e].decode('utf-8').strip()
    if len(m_str) <= 1:
        # NOTE(review): a blank line is reported as "incomplete" and is never
        # consumed by the reader, so it blocks everything behind it - confirm
        # whether such lines should be skipped instead.
        raise AssertionError('MSNP command line too short')
    m = m_str.split()
    body = None
    if m[0] in _PAYLOAD_COMMANDS:
        if len(m) < 2:
            raise ValueError('payload command {!r} has no length token'.format(m[0]))
        n = int(m.pop())
        if n < 0:
            # A negative length would move `e` backwards; with -len(line) the
            # caller's buffer is not consumed at all and it loops forever.
            raise ValueError('negative payload length {} for {!r}'.format(n, m[0]))
        if e + n > len(d):
            raise AssertionError('incomplete MSNP payload')
        body = d[e:e + n]
        e += n
    return m, body, e
def _encode_msnobj(msnobj: Optional[str]) -> Optional[str]:
if msnobj is None: return None
return quote(msnobj, safe = '')
_PAYLOAD_COMMANDS = {
'UUX', 'MSG', 'QRY', 'NOT', 'ADL', 'FQY', 'RML', 'UUN', 'UUM', 'PUT', 'DEL', 'SDG', 'VAS', 'SDC',
}
def _log(logger: Logger, pre: str, m: Sequence[Any]) -> None:
    """Write message `m` to `logger.debug`, prefixed with `pre`.

    With `settings.DEBUG_FULL` set the message is logged unchanged.
    Otherwise the final element is shortened for some commands: replaced
    by its `len()` for UUX/MSG/SDG/ADL/SDC, and by '<truncated>' for
    CHG/ILN/NLN when it contains the text 'msnobj'.
    """
    if settings.DEBUG_FULL:
        logger.debug(pre, *m)
        return

    command = m[0]
    if command in ('UUX', 'MSG', 'SDG', 'ADL', 'SDC'):
        shown = (*m[:-1], len(m[-1]))
    elif command in ('CHG', 'ILN', 'NLN') and 'msnobj' in m[-1]:
        shown = (*m[:-1], '<truncated>')
    else:
        shown = m
    logger.debug(pre, *shown)