RecStore 计算层
概述
计算层负责模型训练和推理中的嵌入向量查询与梯度更新,从参数服务器向上,通过 OP 层提供统一的 C++ 接口,再到 PyTorch Python 绑定,最后接入推荐系统模型代码。
架构分层
| 层级 |
模块 |
文件路径 |
说明 |
| 7 |
推荐模型 |
用户代码 |
DLRM, Wide&Deep 等 |
| 6 |
Embedding 模块 |
src/python/pytorch/torchrec/EmbeddingBag.py |
EmbeddingBag / DistEmbedding |
| 5 |
KV 客户端 |
src/python/pytorch/recstore/KVClient.py |
RecStoreClient 单例 |
| 4 |
PyTorch 扩展 |
src/framework/pytorch/op_torch.cc |
torch.ops.recstore_ops |
| 3 |
C++ 接口 |
src/framework/op.h |
CommonOp / KVClientOp |
| 2 |
PS 客户端 |
src/ps/client/ |
BasePSClient |
| 1 |
PS 服务 |
src/ps/server/ |
gRPC/bRPC 通信 |
| 0 |
存储层 |
src/storage/ |
BaseKV / 引擎 / 内存管理 |
数据流
前向传播
| 步骤 |
组件 |
文件路径 |
代码/操作 |
| 1 |
推荐模型 |
用户代码 |
features = get_batch() |
| 2 |
EmbeddingBag |
src/python/pytorch/torchrec/EmbeddingBag.py |
output = emb_module(features) |
| 3 |
KVClient |
src/python/pytorch/recstore/KVClient.py |
client.pull(name, ids) |
| 4 |
PyTorch 扩展 |
src/framework/pytorch/op_torch.cc |
torch.ops.recstore_ops.emb_read() |
| 5 |
CommonOp |
src/framework/op.h |
op->EmbRead(rec_keys, rec_values) |
| 6 |
BasePSClient |
src/ps/client/ |
与 PS 通信获取向量 |
| 7 |
EmbeddingBag |
src/python/pytorch/torchrec/EmbeddingBag.py |
F.embedding_bag(..., mode="sum") |
| 8 |
返回 |
PyTorch |
[batch_size, num_features, emb_dim] |
反向传播与梯度更新
| 步骤 |
组件 |
文件路径 |
代码/操作 |
| 1 |
用户代码 |
- |
loss.backward() |
| 2 |
EmbeddingBag |
src/python/pytorch/torchrec/EmbeddingBag.py |
触发 _RecStoreEBCFunction.backward |
| 3 |
梯度收集 |
同上 |
将 (ID, grad) 追踪到 _trace |
| 4 |
优化器 |
src/python/pytorch/recstore/optimizer.py |
optimizer.step([emb_module]) |
| 5 |
梯度应用 |
KVClient |
client.update(table_name, ids, grads) |
| 6 |
C++ 接口 |
src/framework/op.h |
op->EmbUpdate(keys, grads) |
| 7 |
PS 客户端 |
src/ps/client/ |
Get → Update → Put |
| 8 |
参数更新完成 |
- |
PS 中的嵌入向量更新 |
异步预取流程
| 步骤 |
组件 |
文件路径 |
代码/操作 |
| 1 |
模型 |
用户代码 |
prefetch_id = client.prefetch(ids) |
| 2 |
KVClient |
src/python/pytorch/recstore/KVClient.py |
返回唯一 prefetch_id |
| 3 |
后台读取 |
src/framework/pytorch/op_torch.cc |
op->EmbPrefetch() 异步执行 |
| 4 |
计算重叠 |
用户代码 |
执行其他计算,不阻塞 |
| 5 |
等待完成 |
KVClient |
result = client.wait_for_prefetch(prefetch_id) |
| 6 |
获取结果 |
C++ 侧 |
op->GetPretchResult() |
| 7 |
返回 |
PyTorch |
[N, embedding_dim] 张量 |
模块说明
详细文档:
配置示例
初始化嵌入表
kv_client = RecStoreClient(library_path="/path/to/lib_recstore_ops.so")
# 初始化嵌入表
kv_client.init_data(
name="user_embedding",
shape=(1000000, 128), # num_embeddings, embedding_dim
dtype=torch.float32,
init_func=lambda shape, dtype: torch.randn(shape, dtype=dtype) * 0.01
)
创建 EmbeddingBag
emb_configs = [
# ...
]
emb_bag = RecStoreEmbeddingBagCollection(
embedding_bag_configs=emb_configs,
lr=0.01,
enable_fusion=True,
fusion_k=30
)