mirror of
https://git.ugnet.gay/CrossTalk/azul.git
synced 2026-05-27 22:59:49 +00:00
210 lines
6.0 KiB
Python
210 lines
6.0 KiB
Python
import io, asyncio, binascii, struct, settings
|
|
from abc import ABCMeta, abstractmethod
|
|
from typing import Tuple, Any, Optional, Callable, Iterable
|
|
|
|
from core import error
|
|
from util.misc import Logger, MultiDict
|
|
from .misc import YMSGStatus, YMSGService
|
|
|
|
# Key/value payload of a YMSG packet: raw bytes keys mapped to raw bytes values.
# Built with MultiDict.add() by the decoder, so a key may presumably repeat — verify
# against util.misc.MultiDict.
KVS = MultiDict[bytes, bytes]
|
|
|
|
class YMSGCtrlBase(metaclass=ABCMeta):
    """Per-connection YMSG controller: decodes inbound packets and sends replies.

    Subclasses implement `_on_close` and provide one handler per service, named
    `_y_` followed by the service code as four lowercase hex digits. Each handler
    is called with (version, vendor_id, status, session_id, kvs).
    """

    __slots__ = ('logger', 'decoder', 'encoder', 'peername', 'closed', 'close_callback', 'transport', 'session_id')

    logger: Logger
    decoder: 'YMSGDecoder'
    encoder: 'YMSGEncoder'
    peername: Tuple[str, int]
    close_callback: Optional[Callable[[], None]]
    closed: bool
    transport: Optional[asyncio.WriteTransport]
    session_id: int

    def __init__(self, logger: Logger) -> None:
        self.logger = logger
        self.decoder = YMSGDecoder(logger)
        self.encoder = YMSGEncoder(logger)
        # Placeholder until the first data_received() reads the real peer address.
        self.peername = ('0.0.0.0', 5050)
        self.closed = False
        self.close_callback = None
        self.transport = None
        self.session_id = 0

    def data_received(self, data: bytes, *, transport: Optional[asyncio.BaseTransport] = None) -> None:
        """Feed raw bytes to the decoder and dispatch each packet to its `_y_XXXX` handler.

        Packets with version > 18 or a vendor id outside (0, 1, 100) are skipped.
        Exceptions raised while dispatching a packet are logged, not propagated;
        exceptions raised by the decoder itself do propagate.
        """
        active = self.transport if transport is None else transport
        assert active is not None
        self.peername = active.get_extra_info('peername')

        for packet in self.decoder.data_received(data):
            try:
                version, vendor_id = packet[1], packet[2]
                if version > 18 or vendor_id not in (0, 1, 100):
                    continue
                # A non-zero session id in the packet becomes the connection's id.
                if packet[4]:
                    self.session_id = packet[4]
                service_hex = binascii.hexlify(struct.pack('!H', packet[0])).decode()
                handler = getattr(self, '_y_{}'.format(service_hex))
                handler(*packet[1:])
            except Exception as ex:
                self.logger.error(ex)

    def send_reply(self, service: YMSGService, status: YMSGStatus, session_id: int, kvs: Optional[KVS] = None) -> None:
        """Encode one packet and, when a transport is attached, write out the buffer.

        A `session_id` of 0 means "use this connection's session id". Packets the
        encoder rejects as too large are dropped silently.
        """
        sid = self.session_id if session_id == 0 else session_id
        try:
            self.encoder.encode(service, status, sid, kvs)
        except error.DataTooLargeToSend:
            return

        transport = self.transport
        if transport is None:
            return
        transport.write(self.flush())

    def flush(self) -> bytes:
        """Return and clear everything the encoder has buffered."""
        return self.encoder.flush()

    def close(self, **kwargs: Any) -> None:
        """Close once: mark closed, run `close_callback` if set, then `_on_close(**kwargs)`."""
        if not self.closed:
            self.closed = True
            callback = self.close_callback
            if callback:
                callback()
            self._on_close(**kwargs)

    @abstractmethod
    def _on_close(self, remove_sess_id: bool = True) -> None: pass
|
|
|
|
class YMSGEncoder:
    """Serializes outgoing YMSG packets into an in-memory buffer.

    `encode` appends packets to the buffer; `flush` hands back the accumulated
    bytes and starts a fresh buffer.
    """

    __slots__ = ('_logger', '_buf')

    _logger: Logger
    _buf: io.BytesIO

    def __init__(self, logger: Logger) -> None:
        self._logger = logger
        self._buf = io.BytesIO()

    def encode(self, service: YMSGService, status: YMSGStatus, session_id: int, kvs: Optional[KVS] = None) -> None:
        """Append one packet to the buffer.

        Each key/value pair of `kvs` is written as ``key SEP value SEP``.

        Raises:
            error.DataTooLargeToSend: the payload does not fit the 16-bit length field.
            struct.error: a header field does not fit its packed width. Nothing is
                buffered in that case.
        """
        payload = b''
        if kvs is not None:
            payload = b''.join(part for k, v in kvs.items() for part in (k, SEP, v, SEP))

        # TODO: Yahoo!'s servers used to split large payloads into packet chunks,
        # but there's little information on how it was exactly handled.
        # Just drop packets if they're too big (for the length field to handle unfortunately) until we can find a solution.
        if len(payload) > 0xffff:
            raise error.DataTooLargeToSend()

        # Have to call `int` on these because they might be an IntEnum, which
        # get `repr`'d to `EnumName.ValueName`. Grr.
        # Pack the header BEFORE writing anything: if a field is out of range,
        # struct.error is raised without leaving a truncated packet in the buffer
        # that would corrupt the next flush.
        header = struct.pack('!HHII', len(payload), int(service), int(status), session_id)

        w = self._buf.write
        w(PRE)
        # version number and vendor id are replaced with 0x00000000
        w(b'\x00\x00\x00\x00')
        w(header)
        w(payload)

        self._logger.debug('[Server]', service, status, session_id)
        if kvs:
            _truncated_kvs(service, kvs)

    def flush(self) -> bytes:
        """Return everything buffered since the last flush and reset the buffer."""
        data = self._buf.getvalue()
        if data:
            self._buf = io.BytesIO()
        return data
|
|
|
|
# One decoded inbound packet, in the order _try_decode_ymsg builds it:
# (service, version, vendor_id, status, session_id, kvs).
DecodedYMSG = Tuple[YMSGService, int, int, YMSGStatus, int, KVS]
|
|
|
|
class YMSGDecoder:
    """Incrementally splits an inbound byte stream into decoded YMSG packets.

    Bytes are buffered across calls, so a packet split over several reads is
    decoded once its final byte arrives instead of raising.
    """

    __slots__ = ('logger', '_data', '_i')

    # Fixed header size: 'YMSG' prefix (4), version (2), vendor id (2),
    # payload length (2), service (2), status (4), session id (4).
    _HEADER_LEN = 20

    logger: Logger
    _data: bytes
    _i: int

    def __init__(self, logger: Logger) -> None:
        self.logger = logger
        self._data = b''  # received bytes not yet decoded
        self._i = 0  # offset in _data where the next packet starts

    def data_received(self, data: bytes) -> Iterable[DecodedYMSG]:
        """Buffer `data` and yield every packet that is now complete.

        A trailing partial packet stays buffered for the next call. Errors from
        decoding a complete but malformed packet propagate to the caller.
        """
        if self._data:
            self._data += data
        else:
            self._data = data
        while self._data:
            y = self._ymsg_read()
            if y is None:
                break
            yield y

    def _packet_available(self) -> bool:
        """Return False while the buffer holds only the beginning of a packet."""
        avail = len(self._data) - self._i
        if avail < self._HEADER_LEN:
            return False
        if self._data[self._i:self._i + 4] != PRE:
            # Not a YMSG header: let the decoder reject it now rather than
            # waiting on a payload length read from garbage.
            return True
        # Payload length is the big-endian u16 at offset 8 of the packet.
        (n,) = struct.unpack_from('!H', self._data, self._i + 8)
        return avail >= self._HEADER_LEN + n

    def _ymsg_read(self) -> Optional[DecodedYMSG]:
        """Decode and consume the next buffered packet; None if it is still incomplete."""
        if not self._packet_available():
            return None
        try:
            y, e = _try_decode_ymsg(self._data, self._i)
        except Exception:
            print("ERR _ymsg_read", self._data)
            raise

        # `e` is the packet length measured from `_i`.
        self._data = self._data[self._i + e:]
        self._i = 0
        self.logger.debug('[Client]', 'YMSG{}'.format(str(y[1])), y[0], y[3], y[4])
        _truncated_kvs(y[0], y[5])
        return y
|
|
|
|
def _try_decode_ymsg(d: bytes, i: int) -> Tuple[DecodedYMSG, int]:
    """Decode the YMSG packet that starts at offset `i` of `d`.

    Returns:
        ``(packet, length)`` where ``packet`` is
        ``(service, version, vendor_id, status, session_id, kvs)`` and ``length``
        is the number of bytes the packet occupies, counted from `i`
        (20-byte header plus payload).

    Raises:
        AssertionError: not enough bytes after `i` for the header or payload, no
            'YMSG' prefix, a version outside YMSG_DIALECTS, or an odd number of
            key/value fields. Raised explicitly (not via ``assert``) so the checks
            still run under ``python -O``.
        Whatever YMSGService()/YMSGStatus() raise for unrecognized codes.
    """
    kvs = MultiDict() # type: KVS

    e = 20
    avail = len(d) - i
    if avail < e:
        raise AssertionError('incomplete YMSG header')
    if d[i:i+4] != PRE:
        raise AssertionError('missing YMSG prefix')
    header = d[i+4:i+e]
    # Versions 8-10 may arrive little-endian (b'\x08\x00' etc.); decode those accordingly.
    if header[:2] in (b'\x08\x00', b'\x09\x00', b'\x0a\x00'):
        version = struct.unpack('<H', header[:2])[0] # type: int
    else:
        version = struct.unpack('!H', header[:2])[0]
    (vendor_id, n, service, status, session_id) = struct.unpack('!HHHII', header[2:]) # type: Tuple[int, int, int, int, int]
    if version not in YMSG_DIALECTS:
        raise AssertionError('unsupported YMSG version {}'.format(version))
    if e + n > avail:
        raise AssertionError('incomplete YMSG payload')

    # The payload follows this packet's header, i.e. it starts at i + e (not e).
    payload = d[i+e:i+e+n]
    if payload:
        parts = payload.split(SEP)
        # Every field is terminated by SEP, so the split leaves a trailing element.
        del parts[-1]
        if len(parts) % 2 != 0:
            raise AssertionError('odd number of YMSG key/value fields')
        for k, v in zip(parts[::2], parts[1::2]):
            kvs.add(k, v)
    e += n
    return ((YMSGService(service), version, vendor_id, YMSGStatus(status), session_id, kvs), e)
|
|
|
|
def _truncated_kvs(service: YMSGService, kvs: KVS) -> None:
    """Debug-print a packet's key/value pairs when settings.DEBUG is enabled.

    For each service, a set of restricted keys is listed below. Their values are
    not printed verbatim; only their length is shown.
    """
    if not settings.DEBUG:
        return

    restricted_keys = set()

    if service in (YMSGService.AuthResp, YMSGService.List):
        restricted_keys.add(b'59')
    if service in (
        YMSGService.Message, YMSGService.MassMessage, YMSGService.ContactNew, YMSGService.FriendAdd,
        YMSGService.ContactDeny, YMSGService.ConfDecline, YMSGService.ConfMsg, YMSGService.P2PFileXfer, YMSGService.FileTransfer
    ):
        restricted_keys.add(b'14')
    if service in (YMSGService.ConfInvite, YMSGService.ConfAddInvite):
        restricted_keys.add(b'58')
    if service in (YMSGService.P2PFileXfer, YMSGService.FileTransfer):
        restricted_keys.add(b'20')

    for k, v in kvs.items():
        if k in restricted_keys:
            # Previously the restricted set was built but never consulted.
            print('{!r} -> <truncated, {} bytes>'.format(k, len(v)))
        else:
            print('{!r} -> {}'.format(k, v))
|
|
|
|
# Magic bytes that open every YMSG packet header.
PRE = b'YMSG'
# Terminator written after every key and every value in a packet payload.
SEP = b'\xC0\x80'

# Protocol versions the decoder accepts in a packet header.
YMSG_DIALECTS = [
    # Accepted by the decoder but not actually supported
    18, 17, 16, 8,
    # Actually supported: 15 down to 9
    *range(15, 8, -1),
]