跳转至

Embedding 与 Optimizer

计算层的主入口是 RecStoreEmbeddingBagCollectionDistEmbedding 是单表封装,适合简单实验和兼容旧训练代码;它的更新路径和 EBC 不一样。

模块 文件 使用场景
RecStoreEmbeddingBagCollection src/python/pytorch/torchrec_kv/EmbeddingBag.py TorchRec 风格多表、多特征模型。当前模型接入优先看它。
DistEmbedding src/python/pytorch/recstore/DistEmb.py 单个 embedding table。通过 DistTensor 读写,optimizer 在 Python 侧做 SGD 写回。

Embedding 模块

RecStoreEmbeddingBagCollection

module = RecStoreEmbeddingBagCollection(
    embedding_bag_configs=configs,
    lr=0.01,
    enable_fusion=True,
    fusion_k=30,
)

输入是 KeyedJaggedTensor。模块按 feature 找到 table,读取 ids 对应的 embedding,再用 torch.nn.functional.embedding_bag 做 sum pooling。

DistEmbedding

emb = DistEmbedding(
    num_embeddings=100000,
    embedding_dim=64,
    name="user_embedding",
    init_func=None,
    lr=0.01,
)

forward(ids) 通过 DistTensor.__getitem__client.pull。backward 会对重复 id 做 torch.unique + index_add_,把聚合后的 (ids, grads) 写入 _trace

fused id

开启 fusion 后,表前缀写进高位:fused_id = raw_id + (table_idx << fusion_k)

fusion_k 时要确认 raw id 不会覆盖表前缀。

预取

handle = module.issue_fused_prefetch(features)
module.set_fused_prefetch_handle(handle)
out = module(features)

issue_fused_prefetch 去重 fused ids 并返回 handle。forward 看到 _fused_prefetch_handle 后等待结果,再按 inverse index 恢复原始顺序。没有 handle 时走同步读取。

梯度和 optimizer

Python 侧 optimizer 先消费模块 _trace。EBC 和 DistEmbedding 后续更新位置不同。

顺序 操作 文件
1 forward 读取 embedding。 torchrec_kv/EmbeddingBag.pyrecstore/DistEmb.py
2 backward 写入 module._trace 同上
3 SparseOptimizer.step() 读取并聚合 trace。 recstore/optimizer.py
4 EBC 调 RecStoreClient.update_async(table, ids, grads)DistEmbedding 读当前权重并在 Python 侧计算新值。 recstore/optimizer.py
5 EBC 进入 KVClientOp::EmbUpdate 和 PS optimizer;DistEmbedding 写回 DistTensor src/framework/op.ccsrc/optimizer/optimizer.cpprecstore/DistTensor.py
模块 trace item
RecStoreEmbeddingBagCollection (table_name, ids, grads)
DistEmbedding (ids, grads)

SparseOptimizer.step() 处理后会清理 trace。跳过这一步会让下一轮 batch 混入旧梯度。

EBC 更新

SparseSGD 聚合 _trace 后调用 kv_client.update_async(name, ids, grads)。学习率和优化器公式由后端处理。local_shm / HierKV 快路径会走 local_update_flat

DistEmbedding 更新

SparseSGD 会合并重复 id,读取 mod.weight[unique_ids],在 Python 侧执行 current - lr * grad,再写回 DistTensor。这条路径没有使用后端 optimizer 状态。

后端 optimizer

后端实现位于 src/optimizer/optimizer.cpp。SGD 无额外状态;AdaGrad 和 RowWiseAdaGrad 的累积状态都在后端维护。EBC 的远端更新应保持这个边界。

修改检查

改 trace 结构:同步更新 SparseOptimizer 和相关单测。

改后端 optimizer:补跑 EBC / PS 更新路径测试,确认 read-modify-write 顺序。

DistEmbedding:检查 DistTensor.__getitem__ / __setitem__ 和 Python 侧 SGD 聚合。

改本地快路径:检查 local_update_flat 和 RPC fallback。