qemu

FORK: QEMU emulator
git clone https://git.neptards.moe/neptards/qemu.git
Log | Files | Refs | Submodules | LICENSE

protocol.py (18678B)


      1 import asyncio
      2 from contextlib import contextmanager
      3 import os
      4 import socket
      5 from tempfile import TemporaryDirectory
      6 
      7 import avocado
      8 
      9 from qemu.qmp import ConnectError, Runstate
     10 from qemu.qmp.protocol import AsyncProtocol, StateError
     11 from qemu.qmp.util import asyncio_run, create_task
     12 
     13 
     14 class NullProtocol(AsyncProtocol[None]):
     15     """
     16     NullProtocol is a test mockup of an AsyncProtocol implementation.
     17 
     18     It adds a fake_session instance variable that enables a code path
     19     that bypasses the actual connection logic, but still allows the
     20     reader/writers to start.
     21 
     22     Because the message type is defined as None, an asyncio.Event named
     23     'trigger_input' is created that prohibits the reader from
     24     incessantly being able to yield None; this event can be poked to
     25     simulate an incoming message.
     26 
     27     For testing symmetry with do_recv, an interface is added to "send" a
     28     Null message.
     29 
     30     For testing purposes, a "simulate_disconnection" method is also
     31     added which allows us to trigger a bottom half disconnect without
     32     injecting any real errors into the reader/writer loops; in essence
     33     it performs exactly half of what disconnect() normally does.
     34     """
     35     def __init__(self, name=None):
     36         self.fake_session = False
     37         self.trigger_input: asyncio.Event
     38         super().__init__(name)
     39 
     40     async def _establish_session(self):
     41         self.trigger_input = asyncio.Event()
     42         await super()._establish_session()
     43 
     44     async def _do_start_server(self, address, ssl=None):
     45         if self.fake_session:
     46             self._accepted = asyncio.Event()
     47             self._set_state(Runstate.CONNECTING)
     48             await asyncio.sleep(0)
     49         else:
     50             await super()._do_start_server(address, ssl)
     51 
     52     async def _do_accept(self):
     53         if self.fake_session:
     54             self._accepted = None
     55         else:
     56             await super()._do_accept()
     57 
     58     async def _do_connect(self, address, ssl=None):
     59         if self.fake_session:
     60             self._set_state(Runstate.CONNECTING)
     61             await asyncio.sleep(0)
     62         else:
     63             await super()._do_connect(address, ssl)
     64 
     65     async def _do_recv(self) -> None:
     66         await self.trigger_input.wait()
     67         self.trigger_input.clear()
     68 
     69     def _do_send(self, msg: None) -> None:
     70         pass
     71 
     72     async def send_msg(self) -> None:
     73         await self._outgoing.put(None)
     74 
     75     async def simulate_disconnect(self) -> None:
     76         """
     77         Simulates a bottom-half disconnect.
     78 
     79         This method schedules a disconnection but does not wait for it
     80         to complete. This is used to put the loop into the DISCONNECTING
     81         state without fully quiescing it back to IDLE. This is normally
     82         something you cannot coax AsyncProtocol to do on purpose, but it
     83         will be similar to what happens with an unhandled Exception in
     84         the reader/writer.
     85 
     86         Under normal circumstances, the library design requires you to
     87         await on disconnect(), which awaits the disconnect task and
     88         returns bottom half errors as a pre-condition to allowing the
     89         loop to return back to IDLE.
     90         """
     91         self._schedule_disconnect()
     92 
     93 
     94 class LineProtocol(AsyncProtocol[str]):
     95     def __init__(self, name=None):
     96         super().__init__(name)
     97         self.rx_history = []
     98 
     99     async def _do_recv(self) -> str:
    100         raw = await self._readline()
    101         msg = raw.decode()
    102         self.rx_history.append(msg)
    103         return msg
    104 
    105     def _do_send(self, msg: str) -> None:
    106         assert self._writer is not None
    107         self._writer.write(msg.encode() + b'\n')
    108 
    109     async def send_msg(self, msg: str) -> None:
    110         await self._outgoing.put(msg)
    111 
    112 
    113 def run_as_task(coro, allow_cancellation=False):
    114     """
    115     Run a given coroutine as a task.
    116 
    117     Optionally, wrap it in a try..except block that allows this
    118     coroutine to be canceled gracefully.
    119     """
    120     async def _runner():
    121         try:
    122             await coro
    123         except asyncio.CancelledError:
    124             if allow_cancellation:
    125                 return
    126             raise
    127     return create_task(_runner())
    128 
    129 
    130 @contextmanager
    131 def jammed_socket():
    132     """
    133     Opens up a random unused TCP port on localhost, then jams it.
    134     """
    135     socks = []
    136 
    137     try:
    138         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    139         sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    140         sock.bind(('127.0.0.1', 0))
    141         sock.listen(1)
    142         address = sock.getsockname()
    143 
    144         socks.append(sock)
    145 
    146         # I don't *fully* understand why, but it takes *two* un-accepted
    147         # connections to start jamming the socket.
    148         for _ in range(2):
    149             sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    150             sock.connect(address)
    151             socks.append(sock)
    152 
    153         yield address
    154 
    155     finally:
    156         for sock in socks:
    157             sock.close()
    158 
    159 
    160 class Smoke(avocado.Test):
    161 
    162     def setUp(self):
    163         self.proto = NullProtocol()
    164 
    165     def test__repr__(self):
    166         self.assertEqual(
    167             repr(self.proto),
    168             "<NullProtocol runstate=IDLE>"
    169         )
    170 
    171     def testRunstate(self):
    172         self.assertEqual(
    173             self.proto.runstate,
    174             Runstate.IDLE
    175         )
    176 
    177     def testDefaultName(self):
    178         self.assertEqual(
    179             self.proto.name,
    180             None
    181         )
    182 
    183     def testLogger(self):
    184         self.assertEqual(
    185             self.proto.logger.name,
    186             'qemu.qmp.protocol'
    187         )
    188 
    189     def testName(self):
    190         self.proto = NullProtocol('Steve')
    191 
    192         self.assertEqual(
    193             self.proto.name,
    194             'Steve'
    195         )
    196 
    197         self.assertEqual(
    198             self.proto.logger.name,
    199             'qemu.qmp.protocol.Steve'
    200         )
    201 
    202         self.assertEqual(
    203             repr(self.proto),
    204             "<NullProtocol name='Steve' runstate=IDLE>"
    205         )
    206 
    207 
    208 class TestBase(avocado.Test):
    209 
    210     def setUp(self):
    211         self.proto = NullProtocol(type(self).__name__)
    212         self.assertEqual(self.proto.runstate, Runstate.IDLE)
    213         self.runstate_watcher = None
    214 
    215     def tearDown(self):
    216         self.assertEqual(self.proto.runstate, Runstate.IDLE)
    217 
    218     async def _asyncSetUp(self):
    219         pass
    220 
    221     async def _asyncTearDown(self):
    222         if self.runstate_watcher:
    223             await self.runstate_watcher
    224 
    225     @staticmethod
    226     def async_test(async_test_method):
    227         """
    228         Decorator; adds SetUp and TearDown to async tests.
    229         """
    230         async def _wrapper(self, *args, **kwargs):
    231             loop = asyncio.get_event_loop()
    232             loop.set_debug(True)
    233 
    234             await self._asyncSetUp()
    235             await async_test_method(self, *args, **kwargs)
    236             await self._asyncTearDown()
    237 
    238         return _wrapper
    239 
    240     # Definitions
    241 
    242     # The states we expect a "bad" connect/accept attempt to transition through
    243     BAD_CONNECTION_STATES = (
    244         Runstate.CONNECTING,
    245         Runstate.DISCONNECTING,
    246         Runstate.IDLE,
    247     )
    248 
    249     # The states we expect a "good" session to transition through
    250     GOOD_CONNECTION_STATES = (
    251         Runstate.CONNECTING,
    252         Runstate.RUNNING,
    253         Runstate.DISCONNECTING,
    254         Runstate.IDLE,
    255     )
    256 
    257     # Helpers
    258 
    259     async def _watch_runstates(self, *states):
    260         """
    261         This launches a task alongside (most) tests below to confirm that
    262         the sequence of runstate changes that occur is exactly as
    263         anticipated.
    264         """
    265         async def _watcher():
    266             for state in states:
    267                 new_state = await self.proto.runstate_changed()
    268                 self.assertEqual(
    269                     new_state,
    270                     state,
    271                     msg=f"Expected state '{state.name}'",
    272                 )
    273 
    274         self.runstate_watcher = create_task(_watcher())
    275         # Kick the loop and force the task to block on the event.
    276         await asyncio.sleep(0)
    277 
    278 
    279 class State(TestBase):
    280 
    281     @TestBase.async_test
    282     async def testSuperfluousDisconnect(self):
    283         """
    284         Test calling disconnect() while already disconnected.
    285         """
    286         await self._watch_runstates(
    287             Runstate.DISCONNECTING,
    288             Runstate.IDLE,
    289         )
    290         await self.proto.disconnect()
    291 
    292 
    293 class Connect(TestBase):
    294     """
    295     Tests primarily related to calling Connect().
    296     """
    297     async def _bad_connection(self, family: str):
    298         assert family in ('INET', 'UNIX')
    299 
    300         if family == 'INET':
    301             await self.proto.connect(('127.0.0.1', 0))
    302         elif family == 'UNIX':
    303             await self.proto.connect('/dev/null')
    304 
    305     async def _hanging_connection(self):
    306         with jammed_socket() as addr:
    307             await self.proto.connect(addr)
    308 
    309     async def _bad_connection_test(self, family: str):
    310         await self._watch_runstates(*self.BAD_CONNECTION_STATES)
    311 
    312         with self.assertRaises(ConnectError) as context:
    313             await self._bad_connection(family)
    314 
    315         self.assertIsInstance(context.exception.exc, OSError)
    316         self.assertEqual(
    317             context.exception.error_message,
    318             "Failed to establish connection"
    319         )
    320 
    321     @TestBase.async_test
    322     async def testBadINET(self):
    323         """
    324         Test an immediately rejected call to an IP target.
    325         """
    326         await self._bad_connection_test('INET')
    327 
    328     @TestBase.async_test
    329     async def testBadUNIX(self):
    330         """
    331         Test an immediately rejected call to a UNIX socket target.
    332         """
    333         await self._bad_connection_test('UNIX')
    334 
    335     @TestBase.async_test
    336     async def testCancellation(self):
    337         """
    338         Test what happens when a connection attempt is aborted.
    339         """
    340         # Note that accept() cannot be cancelled outright, as it isn't a task.
    341         # However, we can wrap it in a task and cancel *that*.
    342         await self._watch_runstates(*self.BAD_CONNECTION_STATES)
    343         task = run_as_task(self._hanging_connection(), allow_cancellation=True)
    344 
    345         state = await self.proto.runstate_changed()
    346         self.assertEqual(state, Runstate.CONNECTING)
    347 
    348         # This is insider baseball, but the connection attempt has
    349         # yielded *just* before the actual connection attempt, so kick
    350         # the loop to make sure it's truly wedged.
    351         await asyncio.sleep(0)
    352 
    353         task.cancel()
    354         await task
    355 
    356     @TestBase.async_test
    357     async def testTimeout(self):
    358         """
    359         Test what happens when a connection attempt times out.
    360         """
    361         await self._watch_runstates(*self.BAD_CONNECTION_STATES)
    362         task = run_as_task(self._hanging_connection())
    363 
    364         # More insider baseball: to improve the speed of this test while
    365         # guaranteeing that the connection even gets a chance to start,
    366         # verify that the connection hangs *first*, then await the
    367         # result of the task with a nearly-zero timeout.
    368 
    369         state = await self.proto.runstate_changed()
    370         self.assertEqual(state, Runstate.CONNECTING)
    371         await asyncio.sleep(0)
    372 
    373         with self.assertRaises(asyncio.TimeoutError):
    374             await asyncio.wait_for(task, timeout=0)
    375 
    376     @TestBase.async_test
    377     async def testRequire(self):
    378         """
    379         Test what happens when a connection attempt is made while CONNECTING.
    380         """
    381         await self._watch_runstates(*self.BAD_CONNECTION_STATES)
    382         task = run_as_task(self._hanging_connection(), allow_cancellation=True)
    383 
    384         state = await self.proto.runstate_changed()
    385         self.assertEqual(state, Runstate.CONNECTING)
    386 
    387         with self.assertRaises(StateError) as context:
    388             await self._bad_connection('UNIX')
    389 
    390         self.assertEqual(
    391             context.exception.error_message,
    392             "NullProtocol is currently connecting."
    393         )
    394         self.assertEqual(context.exception.state, Runstate.CONNECTING)
    395         self.assertEqual(context.exception.required, Runstate.IDLE)
    396 
    397         task.cancel()
    398         await task
    399 
    400     @TestBase.async_test
    401     async def testImplicitRunstateInit(self):
    402         """
    403         Test what happens if we do not wait on the runstate event until
    404         AFTER a connection is made, i.e., connect()/accept() themselves
    405         initialize the runstate event. All of the above tests force the
    406         initialization by waiting on the runstate *first*.
    407         """
    408         task = run_as_task(self._hanging_connection(), allow_cancellation=True)
    409 
    410         # Kick the loop to coerce the state change
    411         await asyncio.sleep(0)
    412         assert self.proto.runstate == Runstate.CONNECTING
    413 
    414         # We already missed the transition to CONNECTING
    415         await self._watch_runstates(Runstate.DISCONNECTING, Runstate.IDLE)
    416 
    417         task.cancel()
    418         await task
    419 
    420 
    421 class Accept(Connect):
    422     """
    423     All of the same tests as Connect, but using the accept() interface.
    424     """
    425     async def _bad_connection(self, family: str):
    426         assert family in ('INET', 'UNIX')
    427 
    428         if family == 'INET':
    429             await self.proto.start_server_and_accept(('example.com', 1))
    430         elif family == 'UNIX':
    431             await self.proto.start_server_and_accept('/dev/null')
    432 
    433     async def _hanging_connection(self):
    434         with TemporaryDirectory(suffix='.qmp') as tmpdir:
    435             sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
    436             await self.proto.start_server_and_accept(sock)
    437 
    438 
    439 class FakeSession(TestBase):
    440 
    441     def setUp(self):
    442         super().setUp()
    443         self.proto.fake_session = True
    444 
    445     async def _asyncSetUp(self):
    446         await super()._asyncSetUp()
    447         await self._watch_runstates(*self.GOOD_CONNECTION_STATES)
    448 
    449     async def _asyncTearDown(self):
    450         await self.proto.disconnect()
    451         await super()._asyncTearDown()
    452 
    453     ####
    454 
    455     @TestBase.async_test
    456     async def testFakeConnect(self):
    457 
    458         """Test the full state lifecycle (via connect) with a no-op session."""
    459         await self.proto.connect('/not/a/real/path')
    460         self.assertEqual(self.proto.runstate, Runstate.RUNNING)
    461 
    462     @TestBase.async_test
    463     async def testFakeAccept(self):
    464         """Test the full state lifecycle (via accept) with a no-op session."""
    465         await self.proto.start_server_and_accept('/not/a/real/path')
    466         self.assertEqual(self.proto.runstate, Runstate.RUNNING)
    467 
    468     @TestBase.async_test
    469     async def testFakeRecv(self):
    470         """Test receiving a fake/null message."""
    471         await self.proto.start_server_and_accept('/not/a/real/path')
    472 
    473         logname = self.proto.logger.name
    474         with self.assertLogs(logname, level='DEBUG') as context:
    475             self.proto.trigger_input.set()
    476             self.proto.trigger_input.clear()
    477             await asyncio.sleep(0)  # Kick reader.
    478 
    479         self.assertEqual(
    480             context.output,
    481             [f"DEBUG:{logname}:<-- None"],
    482         )
    483 
    484     @TestBase.async_test
    485     async def testFakeSend(self):
    486         """Test sending a fake/null message."""
    487         await self.proto.start_server_and_accept('/not/a/real/path')
    488 
    489         logname = self.proto.logger.name
    490         with self.assertLogs(logname, level='DEBUG') as context:
    491             # Cheat: Send a Null message to nobody.
    492             await self.proto.send_msg()
    493             # Kick writer; awaiting on a queue.put isn't sufficient to yield.
    494             await asyncio.sleep(0)
    495 
    496         self.assertEqual(
    497             context.output,
    498             [f"DEBUG:{logname}:--> None"],
    499         )
    500 
    501     async def _prod_session_api(
    502             self,
    503             current_state: Runstate,
    504             error_message: str,
    505             accept: bool = True
    506     ):
    507         with self.assertRaises(StateError) as context:
    508             if accept:
    509                 await self.proto.start_server_and_accept('/not/a/real/path')
    510             else:
    511                 await self.proto.connect('/not/a/real/path')
    512 
    513         self.assertEqual(context.exception.error_message, error_message)
    514         self.assertEqual(context.exception.state, current_state)
    515         self.assertEqual(context.exception.required, Runstate.IDLE)
    516 
    517     @TestBase.async_test
    518     async def testAcceptRequireRunning(self):
    519         """Test that accept() cannot be called when Runstate=RUNNING"""
    520         await self.proto.start_server_and_accept('/not/a/real/path')
    521 
    522         await self._prod_session_api(
    523             Runstate.RUNNING,
    524             "NullProtocol is already connected and running.",
    525             accept=True,
    526         )
    527 
    528     @TestBase.async_test
    529     async def testConnectRequireRunning(self):
    530         """Test that connect() cannot be called when Runstate=RUNNING"""
    531         await self.proto.start_server_and_accept('/not/a/real/path')
    532 
    533         await self._prod_session_api(
    534             Runstate.RUNNING,
    535             "NullProtocol is already connected and running.",
    536             accept=False,
    537         )
    538 
    539     @TestBase.async_test
    540     async def testAcceptRequireDisconnecting(self):
    541         """Test that accept() cannot be called when Runstate=DISCONNECTING"""
    542         await self.proto.start_server_and_accept('/not/a/real/path')
    543 
    544         # Cheat: force a disconnect.
    545         await self.proto.simulate_disconnect()
    546 
    547         await self._prod_session_api(
    548             Runstate.DISCONNECTING,
    549             ("NullProtocol is disconnecting."
    550              " Call disconnect() to return to IDLE state."),
    551             accept=True,
    552         )
    553 
    554     @TestBase.async_test
    555     async def testConnectRequireDisconnecting(self):
    556         """Test that connect() cannot be called when Runstate=DISCONNECTING"""
    557         await self.proto.start_server_and_accept('/not/a/real/path')
    558 
    559         # Cheat: force a disconnect.
    560         await self.proto.simulate_disconnect()
    561 
    562         await self._prod_session_api(
    563             Runstate.DISCONNECTING,
    564             ("NullProtocol is disconnecting."
    565              " Call disconnect() to return to IDLE state."),
    566             accept=False,
    567         )
    568 
    569 
    570 class SimpleSession(TestBase):
    571 
    def setUp(self):
        """Add a LineProtocol server alongside the base fixture."""
        super().setUp()
        server_name = f"{type(self).__name__}-server"
        self.server = LineProtocol(server_name)
    575 
    async def _asyncSetUp(self):
        """Run the base setup, then watch for a full good-session cycle."""
        await super()._asyncSetUp()
        expected_states = self.GOOD_CONNECTION_STATES
        await self._watch_runstates(*expected_states)
    579 
    async def _asyncTearDown(self):
        """Disconnect the client, then the server, then run base teardown."""
        await self.proto.disconnect()

        # An EOFError from the server side is tolerated at this point.
        try:
            await self.server.disconnect()
        except EOFError:
            pass

        await super()._asyncTearDown()
    587 
    588     @TestBase.async_test
    589     async def testSmoke(self):
    590         with TemporaryDirectory(suffix='.qmp') as tmpdir:
    591             sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
    592             server_task = create_task(self.server.start_server_and_accept(sock))
    593 
    594             # give the server a chance to start listening [...]
    595             await asyncio.sleep(0)
    596             await self.proto.connect(sock)