import asyncio
import io
from abc import ABCMeta, abstractmethod
from typing import List, Tuple, Any, Optional, Callable, Iterable, Sequence
from urllib.parse import unquote, quote

import settings
from util.misc import Logger
from .misc import MSNObj


class MSNPCtrl(metaclass = ABCMeta):
    """Base class for one MSNP connection.

    Incoming bytes are parsed by an `MSNPReader`. Each parsed command is
    dispatched to a method named ``_m_<command in lower case>``, called with
    the command's remaining fields (and payload, if any) as positional
    arguments. Outgoing commands are serialized by an `MSNPWriter`.
    Subclasses implement `on_connect` and `_on_close`.
    """

    __slots__ = ('logger', 'reader', 'writer', 'peername', 'closed', 'close_callback', 'transport')

    logger: Logger
    reader: 'MSNPReader'
    writer: 'MSNPWriter'
    peername: Tuple[str, int]
    close_callback: Optional[Callable[[], None]]
    closed: bool
    transport: Optional[asyncio.WriteTransport]

    def __init__(self, logger: Logger) -> None:
        self.logger = logger
        self.reader = MSNPReader(logger)
        self.writer = MSNPWriter(logger)
        # Placeholder; overwritten from the transport on every data_received().
        self.peername = ('0.0.0.0', 1863)
        self.close_callback = None
        self.closed = False
        self.transport = None

    @abstractmethod
    def on_connect(self) -> None:
        pass

    def data_received(self, data: bytes, *, transport: Optional[asyncio.BaseTransport] = None) -> None:
        """Feed raw bytes from the peer and dispatch every complete command.

        `transport` overrides `self.transport` for this call; one of the two
        must be set. An exception raised while dispatching a command is
        logged and does not stop processing of the remaining commands. This
        includes the AttributeError from a missing ``_m_<command>`` method.
        Exceptions raised by the parser itself propagate to the caller.
        """
        if transport is None:
            transport = self.transport
        assert transport is not None
        self.peername = transport.get_extra_info('peername')

        for m in self.reader.data_received(data):
            try:
                f = getattr(self, '_m_{}'.format(m[0].lower()))
                f(*m[1:])
            except Exception as ex:
                self.logger.error(ex)

    def send_reply(self, *m: Any) -> None:
        """Serialize one command and, if a transport is attached, send everything buffered."""
        self.writer.write(m)
        transport = self.transport
        if transport is not None:
            transport.write(self.flush())

    def flush(self) -> bytes:
        """Return (and clear) all bytes buffered by the writer."""
        return self.writer.flush()

    def _m_out(self) -> None:
        # Handler for the OUT command: close this connection.
        self.close()

    def close(self) -> None:
        """Close once: run `close_callback` (if set), then `_on_close`. Later calls do nothing."""
        if self.closed:
            return
        self.closed = True
        if self.close_callback:
            self.close_callback()
        self._on_close()

    @abstractmethod
    def _on_close(self) -> None:
        pass


class MSNPWriter:
    """Serializes outgoing MSNP commands into an in-memory byte buffer."""

    __slots__ = ('_logger', '_buf')

    _logger: Logger
    _buf: io.BytesIO

    def __init__(self, logger: Logger) -> None:
        self._logger = logger
        self._buf = io.BytesIO()

    def write(self, m: Iterable[Any]) -> None:
        """Append one command line (plus optional payload) to the buffer.

        Each field is converted with `str()`. Then '%', ' ', CR and LF are
        percent-encoded, and fields that are None are dropped.

        The last element of `m` is treated specially:

        - ``bytes``: it is a payload. The field is replaced by the payload
          length, and the payload bytes are written after the CRLF.
        - `MSNObj`: its ``.data`` is fully percent-encoded and appended as
          the last field. If ``.data`` is falsy, the field is dropped.
        """
        m = list(m)
        msnobj = None
        data = None
        if isinstance(m[-1], bytes):
            data = m[-1]
            m[-1] = len(data)
        elif isinstance(m[-1], MSNObj):
            msnobj = m[-1].data
            m[-1] = None
        mt = tuple(
            str(x).replace('%', '%25').replace(' ', '%20').replace('\r', '%0D').replace('\n', '%0A')
            for x in m if x is not None
        )
        if msnobj:
            msnobj_encoded = _encode_msnobj(msnobj)
            assert msnobj_encoded is not None
            mt += (msnobj_encoded,)
        _log(self._logger, '[Server]', mt)
        w = self._buf.write
        w(' '.join(mt).encode('utf-8'))
        w(b'\r\n')
        if data is not None:
            w(data)
            if settings.DEBUG_FULL:
                print(data)

    def flush(self) -> bytes:
        """Return the buffered bytes and start a fresh buffer (only if something was buffered)."""
        data = self._buf.getvalue()
        if data:
            self._buf = io.BytesIO()
        return data


class MSNPReader:
    """Incremental parser for incoming MSNP data."""

    __slots__ = ('logger', '_data', '_i')

    logger: Logger
    _data: bytes  # bytes received but not yet consumed as a complete message
    _i: int  # read offset into _data; reset to 0 after each consumed message

    def __init__(self, logger: Logger) -> None:
        self.logger = logger
        self._data = b''
        self._i = 0

    def data_received(self, data: bytes) -> Iterable[List[Any]]:
        """Buffer `data` and yield every complete message now available.

        This is a generator: `data` is only added to the buffer once
        iteration starts. Incomplete trailing data stays buffered for the
        next call.
        """
        if self._data:
            self._data += data
        else:
            self._data = data
        while self._data:
            m = self._read_msnp()
            if m is None:
                break
            yield m

    def _read_msnp(self) -> Optional[List[Any]]:
        """Consume one message from the buffer.

        Returns None if the buffer does not yet hold a complete message.
        Otherwise returns the percent-decoded fields. When the message has a
        non-empty payload, the raw payload bytes are appended as the last
        element. Parse errors (e.g. a bad payload length) are printed with
        the buffer contents and re-raised.
        """
        try:
            m, body, e = _msnp_try_decode(self._data, self._i)
        except _IncompleteMessage:
            return None
        except Exception:
            print("ERR _read_msnp", self._i, self._data)
            raise
        self._data = self._data[e:]
        self._i = 0
        _log(self.logger, '[Client]', m)
        m = [unquote(x) for x in m]
        if body:
            m.append(body)
            if settings.DEBUG_FULL:
                print(body)
        return m

    def _read_raw(self, n: int) -> bytes:
        """Return the next `n` bytes at the read offset and advance it.

        Raises _IncompleteMessage (an AssertionError subclass) if fewer than
        `n` bytes are buffered.
        """
        i = self._i
        e = i + n
        # Explicit raise instead of `assert`, which `python -O` would strip.
        if e > len(self._data):
            raise _IncompleteMessage()
        self._i += n
        return self._data[i:e]


class _IncompleteMessage(AssertionError):
    """Raised when the buffer does not yet hold a whole MSNP message.

    This subclasses AssertionError because the decoder previously signalled
    this case with `assert`, so callers catching AssertionError keep
    working. Unlike `assert`, an explicit raise survives ``python -O``.
    """


def _msnp_try_decode(d: bytes, i: int) -> Tuple[List[Any], Optional[bytes], int]:
    """Try to parse one MSNP message from buffer `d`, starting at index `i`.

    Returns ``(fields, body, end)``:

    - ``fields`` is the whitespace-split command line, still percent-encoded.
    - ``body`` holds the payload bytes for commands in `_PAYLOAD_COMMANDS`.
      For those commands the trailing length field is removed from
      ``fields``. For all other commands ``body`` is None.
    - ``end`` is the index just past the message.

    Lines that are empty or one character long after stripping cannot be a
    command and are skipped.

    Raises:
        _IncompleteMessage: `d` does not yet contain a full message.
        ValueError: a payload command has a missing, non-integer or
            negative length field.
        UnicodeDecodeError: the command line is not valid UTF-8.
    """
    while True:
        e = d.find(b'\n', i)
        if e < 0:
            raise _IncompleteMessage()
        e += 1
        m_str = d[i:e].decode('utf-8').strip()
        if len(m_str) > 1:
            break
        # A complete but too-short line (e.g. a stray blank line) will never
        # become a valid command. Treating it as "incomplete" would stall the
        # reader forever while the buffer keeps growing, so skip it.
        i = e

    m = m_str.split()
    body = None
    if m[0] in _PAYLOAD_COMMANDS:
        if len(m) < 2:
            raise ValueError("MSNP payload command without length: {!r}".format(m_str))
        n = int(m.pop())
        if n < 0:
            # A negative length would move `e` backwards. The caller would then
            # never consume the buffer and would re-parse the same bytes in an
            # endless loop.
            raise ValueError("MSNP negative payload length: {}".format(n))
        if e + n > len(d):
            raise _IncompleteMessage()
        body = d[e:e+n]
        e += n
    return m, body, e


def _encode_msnobj(msnobj: Optional[str]) -> Optional[str]:
    """Percent-encode every reserved character of `msnobj`; None passes through."""
    if msnobj is None:
        return None
    return quote(msnobj, safe = '')


# Commands whose last command-line field is a byte length, followed by that
# many payload bytes.
_PAYLOAD_COMMANDS = {
    'UUX', 'MSG', 'QRY', 'NOT', 'ADL', 'FQY', 'RML', 'UUN', 'UUM', 'PUT', 'DEL', 'SDG', 'VAS', 'SDC',
}


def _log(logger: Logger, pre: str, m: Sequence[Any]) -> None:
    """Debug-log one command's fields, prefixed with `pre`.

    Unless `settings.DEBUG_FULL` is set, the last field is abbreviated:

    - For UUX/MSG/SDG/ADL/SDC it is replaced by its `len()`.
    - For CHG/ILN/NLN whose last field contains 'msnobj' it is replaced by ''.
    """
    # NOTE(review): callers never pass the payload here. For client messages
    # the length field has already been popped, so m[-1] is the preceding
    # field. For server messages m[-1] is the stringified length, so
    # len(m[-1]) logs its digit count. Confirm whether this is intended.
    if settings.DEBUG_FULL:
        logger.debug(pre, *m)
    else:
        if m[0] in ('UUX', 'MSG', 'SDG', 'ADL', 'SDC'):
            logger.debug(pre, *m[:-1], len(m[-1]))
        elif m[0] in ('CHG', 'ILN', 'NLN') and 'msnobj' in m[-1]:
            logger.debug(pre, *m[:-1], '')
        else:
            logger.debug(pre, *m)