# Unit tests for SplitwiseConnector — FastDeploy PR #5062
# (Hackathon 9th Sprint No.41, unit-test supplement by xunyoyo, PaddlePaddle/FastDeploy).
def test_has_splitwise_tasks_detects_prefill_backlog():
    """has_splitwise_tasks reports a backlog only once the prefill pool drains."""
    config = make_cfg(innode_ports=[7001])
    connector = InspectableConnector(config, FakeEngineWorkerQueue(), object())
    connector.create_connection(7001)
    innode_queue = connector.connect_innode_instances[7001]
    # With a prefill instance available there is no pending splitwise work.
    innode_queue.available_prefill_instances.size = 1
    assert connector.has_splitwise_tasks() is False
    # Draining the pool flips the signal.
    innode_queue.available_prefill_instances.size = 0
    assert connector.has_splitwise_tasks() is True
def test_dispatch_innode_splitwise_tasks_promotes_decode_role():
    """Dispatching an innode prefill task enqueues it and flips its role to decode."""
    config = make_cfg(innode_ports=[8002])
    connector = InspectableConnector(config, FakeEngineWorkerQueue(), object())
    connector.create_connection(8002)
    innode_queue = connector.connect_innode_instances[8002]
    innode_queue.prefill_ready = True
    request = make_task("req-dispatch", role="prefill", protocol="ipc")
    connector.dispatch_innode_splitwise_tasks([request], current_id=33)
    # The queue records the entry under the "prefill" label...
    assert innode_queue.disaggregated_tasks[-1][0] == "prefill"
    # ...while the task itself is promoted to the decode role with the id stamped.
    assert request.disaggregate_info["role"] == "decode"
    assert request.disaggregate_info["cache_info"]["ipc"]["current_id"] == 33
def test_send_splitwise_tasks_dispatches_when_innode_ports_available():
    """With an innode port configured, send_splitwise_tasks routes via that queue."""
    config = make_cfg(innode_ports=[8100])
    connector = InspectableConnector(config, FakeEngineWorkerQueue(), object())
    connector.create_connection(8100)
    connector.connect_innode_instances[8100].prefill_ready = True
    request = make_task("req-prefill", role="prefill", protocol="ipc")
    connector.send_splitwise_tasks([request], current_id=44)
    # The innode queue received at least one disaggregated entry.
    assert connector.connect_innode_instances[8100].disaggregated_tasks
def test_send_splitwise_tasks_innode_rewrites_ports_for_decode_queue():
    """send_splitwise_tasks_innode swaps ports: the queued copy carries the local
    engine-worker port while the caller's task keeps the innode port."""
    config = make_cfg()
    connector = InspectableConnector(config, FakeEngineWorkerQueue(), object())
    connector.create_connection(8123)
    request = make_task("req-innode", role="decode", protocol="ipc")
    returned_port = connector.send_splitwise_tasks_innode([request], 8123)
    assert returned_port == 8123
    queued_entry = connector.connect_innode_instances[8123].disaggregated_tasks[-1]
    queued_port = queued_entry[1][0].disaggregate_info["cache_info"]["ipc"]["port"]
    assert queued_port == config.parallel_config.engine_worker_queue_port[0]
    assert request.disaggregate_info["cache_info"]["ipc"]["port"] == 8123
def test_send_splitwise_tasks_rdma_routes_and_resets_state():
    """RDMA prefill tasks are sent to the remote address and tracked as "init"."""
    connector = InspectableConnector(make_cfg(), FakeEngineWorkerQueue(), object())
    request = make_task("req-remote", role="prefill", protocol="rdma")
    connector.send_splitwise_tasks([request], current_id=55)
    destination, message_type = connector.sent_messages[-1][:2]
    assert destination == "10.1.0.1:9010"
    assert message_type == "prefill"
    # Tracking state resets, and the task role is left untouched on the RDMA path.
    assert connector.current_request_ids["req-remote"] == "init"
    assert request.disaggregate_info["role"] == "prefill"
def test_send_cache_infos_prefill_batches_into_worker_queue():
    """Prefill cache info lands on the local worker queue; return value says "not decode"."""
    worker_queue = FakeEngineWorkerQueue()
    connector = InspectableConnector(make_cfg(), worker_queue, object())
    request = make_task("req-prefill", role="prefill", protocol="ipc")
    assert connector.send_cache_infos([request], current_id=11) is False
    latest_entry = worker_queue.cache_infos[-1][0]
    assert latest_entry["request_id"] == "req-prefill"
    assert latest_entry["current_id"] == 11
def test_send_cache_infos_decode_rdma_triggers_remote_sync():
    """Decode-over-RDMA cache info is synced remotely and bypasses the local queue."""
    worker_queue = FakeEngineWorkerQueue()
    connector = InspectableConnector(make_cfg(), worker_queue, object())
    request = make_task("req-decode", role="decode", protocol="rdma")
    assert connector.send_cache_infos([request], current_id=22) is True
    assert connector.sent_messages[-1][1] == "cache_sync"
    # Nothing is put on the local worker queue for the remote path.
    assert worker_queue.cache_infos == []
def test_send_cache_infos_decode_ipc_forwards_to_local_worker():
    """Decode-over-IPC cache info is forwarded to the innode queue for that port."""
    connector = InspectableConnector(make_cfg(), FakeEngineWorkerQueue(), object())
    connector.create_connection(9300)
    request = make_task("req-local", role="decode", protocol="ipc")
    request.disaggregate_info["cache_info"]["ipc"]["port"] = 9300
    connector.send_cache_infos([request], current_id=7)
    forwarded = connector.connect_innode_instances[9300].cache_infos[-1][0]
    assert forwarded["transfer_protocol"] == "ipc"
def test_send_cache_infos_rdma_with_error_message_forwards_reason():
    """A task carrying error_msg propagates the failure reason in the cache sync."""
    connector = InspectableConnector(make_cfg(), FakeEngineWorkerQueue(), object())
    request = make_task("req-err", role="decode", protocol="rdma")
    request.error_msg = "remote boom"
    connector.send_cache_infos([request], current_id=0)
    last_message = connector.sent_messages[-1]
    assert last_message[1] == "cache_sync"
    assert "error_msg" in last_message[2][0]
def test_send_first_token_to_ipc_decode_queue():
    """First-token over IPC is enqueued as a decode entry on the target innode queue."""
    connector = InspectableConnector(make_cfg(), FakeEngineWorkerQueue(), object())
    connector.create_connection(9400)
    message = {"transfer_protocol": "ipc", "cache_info": {"ipc": {"port": 9400}}}
    request = make_task("req-first", role="decode", protocol="ipc")
    connector.send_first_token(message, [request])
    queued = connector.connect_innode_instances[9400].disaggregated_tasks[-1]
    assert queued[0] == "decode"
def test_send_first_token_rdma_path():
    """First-token over RDMA is sent to the ip:port taken from the message's
    cache_info, tagged with the "decode" message type.

    Fix: dropped the declared-but-unused ``monkeypatch`` fixture parameter —
    nothing in this test patches anything.
    """
    cfg = make_cfg()
    worker_queue = FakeEngineWorkerQueue()
    connector = InspectableConnector(cfg, worker_queue, object())
    msg = {
        "transfer_protocol": "rdma",
        "cache_info": {"rdma": {"ip": "1.2.3.4", "port": 9123}},
    }
    task = make_task("req-first-rdma", role="decode", protocol="rdma")
    # NOTE(review): the IPC sibling test passes a list of tasks while this call
    # passes a bare task — confirm send_first_token accepts both shapes.
    connector.send_first_token(msg, task)
    assert connector.sent_messages[-1][0] == "1.2.3.4:9123"
    assert connector.sent_messages[-1][1] == "decode"
def test_check_decode_allocated_reports_finish_and_error():
    """check_decode_allocated maps "finished" to success and any other terminal
    status to a failure carrying that status as the message."""
    connector = InspectableConnector(make_cfg(), FakeEngineWorkerQueue(), object())
    finished_task = make_task("req-finish", role="prefill", protocol="rdma")
    connector.current_request_ids["req-finish"] = "finished"
    success, reason = connector.check_decode_allocated(finished_task)
    assert success
    assert reason == ""
    failed_task = make_task("req-error", role="prefill", protocol="rdma")
    connector.current_request_ids["req-error"] = "failed"
    success, reason = connector.check_decode_allocated(failed_task)
    assert success is False
    assert reason == "failed"
def test_process_cache_sync_records_status_and_forwards():
    """A cache_sync message records per-request status — the error message when
    present, otherwise "finished" — and forwards the full payload to the worker
    queue.

    Fix: dropped the declared-but-unused ``monkeypatch`` fixture parameter —
    nothing in this test patches anything.
    """
    cfg = make_cfg()
    worker_queue = FakeEngineWorkerQueue()
    connector = InspectableConnector(cfg, worker_queue, object())
    payload = [
        {"request_id": "req-a", "error_msg": "boom"},
        {"request_id": "req-b"},
    ]
    message = json.dumps({"type": "cache_sync", "payload": payload}).encode("utf-8")
    connector._process_message(message)
    assert connector.current_request_ids["req-a"] == "boom"
    assert connector.current_request_ids["req-b"] == "finished"
    assert worker_queue.cache_infos[-1] == payload
def test_handle_prefill_and_decode_messages():
    """Both the prefill and decode handler paths enqueue "decode" work locally."""
    worker_queue = FakeEngineWorkerQueue()
    connector = InspectableConnector(make_cfg(), worker_queue, object())
    # Prefill path: an incoming request dict is turned into local decode work.
    incoming = make_request_obj("req-handle")
    connector._handle_prefill([incoming.to_dict()])
    assert worker_queue.disaggregated_tasks[-1][0] == "decode"
    # Decode path: a request output dict is likewise routed as decode work.
    completion = CompletionOutput(index=0, send_idx=0, token_ids=[])
    output = RequestOutput(
        "req-out",
        outputs=completion,
        metrics=RequestMetrics(arrival_time=0.0),
    )
    connector._handle_decode([output.to_dict()])
    assert worker_queue.disaggregated_tasks[-1][0] == "decode"
def test_close_connection_removes_socket_reference():
    """_close_connection closes the socket and drops it from push_sockets."""
    connector = InspectableConnector(make_cfg(), FakeEngineWorkerQueue(), object())

    class RecordingSocket:
        # Flips to True when close() is invoked.
        closed = False

        def close(self):
            self.closed = True

    stub_socket = RecordingSocket()
    connector.push_sockets = {"test": stub_socket}
    connector._close_connection("test")
    assert stub_socket.closed is True
    assert connector.push_sockets == {}
def test_send_message_initializes_network_and_serializes(monkeypatch):
    """_send_message lazily builds its zmq socket and executor, then pushes a
    JSON payload whose "type" field matches the requested message type."""
    monkeypatch.setattr(splitwise_connector, "zmq", _make_stub_zmq())

    class RecordingExecutor:
        def __init__(self, *_, **__):
            self.calls = []

        def submit(self, fn, *args, **kwargs):
            self.calls.append((fn, args, kwargs))

    monkeypatch.setattr(splitwise_connector, "ThreadPoolExecutor", RecordingExecutor)
    cfg = make_cfg(
        pd_comm_port=[9550],
        enable_expert_parallel=True,
        data_parallel_size=2,
        local_data_parallel_id=1,
    )
    connector = SplitwiseConnector(cfg, FakeEngineWorkerQueue(), object())
    connector._send_message("127.0.0.1:9551", "decode", [RequestOutput("req-zmq")])
    stub_socket = connector.push_sockets["127.0.0.1:9551"]
    payload = json.loads(stub_socket.sent[-1][1].decode("utf-8"))
    assert payload["type"] == "decode"
def test_send_message_handles_failures_and_resets_socket(monkeypatch):
    """On a send failure the socket is discarded and the failure metric bumps."""
    monkeypatch.setattr(splitwise_connector, "zmq", _make_stub_zmq())
    monkeypatch.setattr(splitwise_connector, "ThreadPoolExecutor", lambda *_, **__: None)
    cfg = make_cfg(pd_comm_port=[9660])
    connector = SplitwiseConnector(cfg, FakeEngineWorkerQueue(), object())
    broken_socket = _StubSocket(2)
    broken_socket.should_fail = True
    connector.push_sockets["node"] = broken_socket
    failure_counter = splitwise_connector.main_process_metrics.send_cache_failed_num
    failure_counter.value = 0
    connector._send_message("node", "decode", [RequestOutput("req-fail")])
    # The broken socket was evicted and the failure was counted exactly once.
    assert "node" not in connector.push_sockets
    assert failure_counter.value == 1