跳转至

DistEmbedding 模块测试

代码位置: src/python/pytorch/recstore/unittest/test_dist_emb.py

简介

DistEmbedding 是 RecStore 的核心构建块,它是 nn.Module 的直接子类。此测试验证它完全遵循 PyTorch 的 Autograd 机制,并能无缝接入 Parameter Server。

流程与测试点

测试方法 描述 图解/逻辑 预期结果
test_initialization 初始化与元数据注册 DistEmbedding(...) -> 连接 Client -> 注册 Table Client 内部元数据结构中应包含该 Table 的 Shape 和 Dtype。
test_backward 梯度追踪 (Tracing) loss.backward() -> Hook 触发 -> 记录 {Keys, Grads} 到 Trace Backward 仅记录 Trace,不更新权重。权重在此时应保持不变。
test_optimizer_step 优化器更新 SparseSGD.step() -> 读取 Trace -> client.push 调用 step() 后,权重发生变化,且符合 SGD 更新公式。
test_duplicate_ids 梯度聚合 输入 Batch: [ID 1, ID 1, ID 1]
Backward: Wait...
Server 端收到的 ID 1 的梯度应为 3 次梯度的总和
test_persistence 数据持久化 Emb(name="A") 退出 -> Emb(name="A") 重启 新实例应能读取到上一个实例写入的数据。