[CI] 【Hackathon 9th Sprint No.41】NO.41 功能模块单测补充 -part by xunyoyo · Pull Request #5062 · PaddlePaddle/FastDeploy (original) (raw)

def test_has_splitwise_tasks_detects_prefill_backlog():
    """has_splitwise_tasks() reports True only when no prefill slot is free."""
    config = make_cfg(innode_ports=[7001])
    connector = InspectableConnector(config, FakeEngineWorkerQueue(), object())
    connector.create_connection(7001)
    instance_queue = connector.connect_innode_instances[7001]

    # A free prefill slot means there is no backlog to report.
    instance_queue.available_prefill_instances.size = 1
    assert connector.has_splitwise_tasks() is False

    # Once the slot pool drains, pending splitwise work is reported.
    instance_queue.available_prefill_instances.size = 0
    assert connector.has_splitwise_tasks() is True

def test_dispatch_innode_splitwise_tasks_promotes_decode_role():
    """Dispatching an innode prefill task flips its local role to decode."""
    config = make_cfg(innode_ports=[8002])
    connector = InspectableConnector(config, FakeEngineWorkerQueue(), object())
    connector.create_connection(8002)
    instance_queue = connector.connect_innode_instances[8002]
    instance_queue.prefill_ready = True

    task = make_task("req-dispatch", role="prefill", protocol="ipc")
    connector.dispatch_innode_splitwise_tasks([task], current_id=33)

    # The task is queued on the innode instance under the "prefill" label...
    assert instance_queue.disaggregated_tasks[-1][0] == "prefill"
    # ...while the caller's copy is promoted to decode with the id recorded.
    assert task.disaggregate_info["role"] == "decode"
    assert task.disaggregate_info["cache_info"]["ipc"]["current_id"] == 33

def test_send_splitwise_tasks_dispatches_when_innode_ports_available():
    """send_splitwise_tasks() routes to an innode queue when one is configured."""
    config = make_cfg(innode_ports=[8100])
    connector = InspectableConnector(config, FakeEngineWorkerQueue(), object())
    connector.create_connection(8100)
    connector.connect_innode_instances[8100].prefill_ready = True

    task = make_task("req-prefill", role="prefill", protocol="ipc")
    connector.send_splitwise_tasks([task], current_id=44)

    # The task landed on the innode instance's disaggregated queue.
    assert connector.connect_innode_instances[8100].disaggregated_tasks

def test_send_splitwise_tasks_innode_rewrites_ports_for_decode_queue():
    """The innode send path rewrites ports: the recorded copy carries the local
    engine worker queue port while the caller's task keeps the innode port."""
    config = make_cfg()
    connector = InspectableConnector(config, FakeEngineWorkerQueue(), object())
    connector.create_connection(8123)

    task = make_task("req-innode", role="decode", protocol="ipc")
    snapshot_port = connector.send_splitwise_tasks_innode([task], 8123)
    recorded = connector.connect_innode_instances[8123].disaggregated_tasks[-1]

    assert snapshot_port == 8123
    local_port = config.parallel_config.engine_worker_queue_port[0]
    assert recorded[1][0].disaggregate_info["cache_info"]["ipc"]["port"] == local_port
    assert task.disaggregate_info["cache_info"]["ipc"]["port"] == 8123

def test_send_splitwise_tasks_rdma_routes_and_resets_state():
    """RDMA tasks are sent to the remote prefill endpoint and tracked as 'init'."""
    config = make_cfg()
    connector = InspectableConnector(config, FakeEngineWorkerQueue(), object())

    task = make_task("req-remote", role="prefill", protocol="rdma")
    connector.send_splitwise_tasks([task], current_id=55)

    destination, message_type = connector.sent_messages[-1][:2]
    assert destination == "10.1.0.1:9010"
    assert message_type == "prefill"
    # The request is registered as awaiting decode allocation...
    assert connector.current_request_ids["req-remote"] == "init"
    # ...and the task's own role is left untouched.
    assert task.disaggregate_info["role"] == "prefill"

def test_send_cache_infos_prefill_batches_into_worker_queue():
    """Prefill cache info is batched into the local engine worker queue."""
    config = make_cfg()
    worker_queue = FakeEngineWorkerQueue()
    connector = InspectableConnector(config, worker_queue, object())

    task = make_task("req-prefill", role="prefill", protocol="ipc")
    was_decode = connector.send_cache_infos([task], current_id=11)

    assert was_decode is False
    latest_entry = worker_queue.cache_infos[-1][0]
    assert latest_entry["request_id"] == "req-prefill"
    assert latest_entry["current_id"] == 11

def test_send_cache_infos_decode_rdma_triggers_remote_sync():
    """Decode-role RDMA cache info goes to the remote peer, not the local queue."""
    config = make_cfg()
    worker_queue = FakeEngineWorkerQueue()
    connector = InspectableConnector(config, worker_queue, object())

    task = make_task("req-decode", role="decode", protocol="rdma")
    assert connector.send_cache_infos([task], current_id=22) is True

    assert connector.sent_messages[-1][1] == "cache_sync"
    # Nothing was placed on the local worker queue.
    assert worker_queue.cache_infos == []

def test_send_cache_infos_decode_ipc_forwards_to_local_worker():
    """Decode-role IPC cache info is forwarded to the innode instance's queue."""
    config = make_cfg()
    connector = InspectableConnector(config, FakeEngineWorkerQueue(), object())
    connector.create_connection(9300)

    task = make_task("req-local", role="decode", protocol="ipc")
    task.disaggregate_info["cache_info"]["ipc"]["port"] = 9300
    connector.send_cache_infos([task], current_id=7)

    forwarded = connector.connect_innode_instances[9300].cache_infos[-1][0]
    assert forwarded["transfer_protocol"] == "ipc"

def test_send_cache_infos_rdma_with_error_message_forwards_reason():
    """An attached error message survives into the RDMA cache_sync payload."""
    config = make_cfg()
    connector = InspectableConnector(config, FakeEngineWorkerQueue(), object())

    task = make_task("req-err", role="decode", protocol="rdma")
    task.error_msg = "remote boom"
    connector.send_cache_infos([task], current_id=0)

    _, message_type, payload = connector.sent_messages[-1][:3]
    assert message_type == "cache_sync"
    assert "error_msg" in payload[0]

def test_send_first_token_to_ipc_decode_queue():
    """send_first_token() over IPC queues the task on the innode instance named
    by the message's cache_info port, labelled 'decode'."""
    config = make_cfg()
    connector = InspectableConnector(config, FakeEngineWorkerQueue(), object())
    connector.create_connection(9400)

    message = {"transfer_protocol": "ipc", "cache_info": {"ipc": {"port": 9400}}}
    task = make_task("req-first", role="decode", protocol="ipc")
    connector.send_first_token(message, [task])

    queued = connector.connect_innode_instances[9400].disaggregated_tasks[-1]
    assert queued[0] == "decode"

def test_send_first_token_rdma_path():
    """send_first_token() over RDMA routes a 'decode' message to ip:port.

    Fix: the original declared a ``monkeypatch`` fixture it never used; the
    dead fixture request is removed.
    """
    config = make_cfg()
    worker_queue = FakeEngineWorkerQueue()
    connector = InspectableConnector(config, worker_queue, object())

    msg = {
        "transfer_protocol": "rdma",
        "cache_info": {"rdma": {"ip": "1.2.3.4", "port": 9123}},
    }
    task = make_task("req-first-rdma", role="decode", protocol="rdma")
    # NOTE(review): the IPC test passes a list of tasks while this passes a
    # bare task — confirm send_first_token accepts both shapes on this path.
    connector.send_first_token(msg, task)

    assert connector.sent_messages[-1][0] == "1.2.3.4:9123"
    assert connector.sent_messages[-1][1] == "decode"

def test_check_decode_allocated_reports_finish_and_error():
    """check_decode_allocated() maps 'finished' to success and any other
    recorded status to a failure carrying that status as the message."""
    config = make_cfg()
    connector = InspectableConnector(config, FakeEngineWorkerQueue(), object())

    finished_task = make_task("req-finish", role="prefill", protocol="rdma")
    connector.current_request_ids["req-finish"] = "finished"
    ok, message = connector.check_decode_allocated(finished_task)
    assert ok
    assert message == ""

    failed_task = make_task("req-error", role="prefill", protocol="rdma")
    connector.current_request_ids["req-error"] = "failed"
    ok, message = connector.check_decode_allocated(failed_task)
    assert ok is False
    assert message == "failed"

def test_process_cache_sync_records_status_and_forwards():
    """_process_message() on a cache_sync records per-request status (the
    error message, or 'finished' when none) and forwards the payload to the
    worker queue.

    Fix: the original declared a ``monkeypatch`` fixture it never used; the
    dead fixture request is removed.
    """
    config = make_cfg()
    worker_queue = FakeEngineWorkerQueue()
    connector = InspectableConnector(config, worker_queue, object())

    payload = [
        {"request_id": "req-a", "error_msg": "boom"},
        {"request_id": "req-b"},
    ]
    message = json.dumps({"type": "cache_sync", "payload": payload}).encode("utf-8")
    connector._process_message(message)

    assert connector.current_request_ids["req-a"] == "boom"
    assert connector.current_request_ids["req-b"] == "finished"
    assert worker_queue.cache_infos[-1] == payload

def test_handle_prefill_and_decode_messages():
    """Both _handle_prefill and _handle_decode push onto the worker queue's
    disaggregated task list under the 'decode' label."""
    config = make_cfg()
    worker_queue = FakeEngineWorkerQueue()
    connector = InspectableConnector(config, worker_queue, object())

    request = make_request_obj("req-handle")
    connector._handle_prefill([request.to_dict()])
    assert worker_queue.disaggregated_tasks[-1][0] == "decode"

    output = RequestOutput(
        "req-out",
        outputs=CompletionOutput(index=0, send_idx=0, token_ids=[]),
        metrics=RequestMetrics(arrival_time=0.0),
    )
    connector._handle_decode([output.to_dict()])
    assert worker_queue.disaggregated_tasks[-1][0] == "decode"

def test_close_connection_removes_socket_reference():
    """_close_connection() closes the socket and drops it from push_sockets."""
    config = make_cfg()
    connector = InspectableConnector(config, FakeEngineWorkerQueue(), object())

    class RecordingSocket:
        def __init__(self):
            self.closed = False

        def close(self):
            self.closed = True

    socket_stub = RecordingSocket()
    connector.push_sockets = {"test": socket_stub}
    connector._close_connection("test")

    assert socket_stub.closed is True
    assert connector.push_sockets == {}

def test_send_message_initializes_network_and_serializes(monkeypatch):
    """_send_message() lazily builds the zmq push socket and sends the payload
    JSON-encoded together with its message type."""
    monkeypatch.setattr(splitwise_connector, "zmq", _make_stub_zmq())

    class RecordingExecutor:
        def __init__(self, *_, **__):
            self.calls = []

        def submit(self, fn, *args, **kwargs):
            self.calls.append((fn, args, kwargs))

    monkeypatch.setattr(splitwise_connector, "ThreadPoolExecutor", RecordingExecutor)

    config = make_cfg(
        pd_comm_port=[9550],
        enable_expert_parallel=True,
        data_parallel_size=2,
        local_data_parallel_id=1,
    )
    connector = SplitwiseConnector(config, FakeEngineWorkerQueue(), object())
    connector._send_message("127.0.0.1:9551", "decode", [RequestOutput("req-zmq")])

    socket_stub = connector.push_sockets["127.0.0.1:9551"]
    envelope = json.loads(socket_stub.sent[-1][1].decode("utf-8"))
    assert envelope["type"] == "decode"

def test_send_message_handles_failures_and_resets_socket(monkeypatch):
    """A send failure drops the broken socket and bumps the failure metric."""
    monkeypatch.setattr(splitwise_connector, "zmq", _make_stub_zmq())
    monkeypatch.setattr(splitwise_connector, "ThreadPoolExecutor", lambda *_, **__: None)

    config = make_cfg(pd_comm_port=[9660])
    connector = SplitwiseConnector(config, FakeEngineWorkerQueue(), object())

    broken_socket = _StubSocket(2)
    broken_socket.should_fail = True
    connector.push_sockets["node"] = broken_socket
    splitwise_connector.main_process_metrics.send_cache_failed_num.value = 0

    connector._send_message("node", "decode", [RequestOutput("req-fail")])

    assert "node" not in connector.push_sockets
    assert splitwise_connector.main_process_metrics.send_cache_failed_num.value == 1