跳转至

CommonOp 接口

概述

CommonOp 是 C++ 侧的统一接口,定义了所有嵌入操作的规范。位于 src/framework/op.h

核心接口

初始化

virtual void EmbInit(const RecTensor& keys, const RecTensor& init_values) = 0;
virtual void EmbInit(const RecTensor& keys, const InitStrategy& strategy) = 0;
virtual bool InitEmbeddingTable(const std::string& table_name, 
                                const EmbeddingTableConfig& config) = 0;
virtual void SetPSConfig(const std::string& host, int port) = 0;
方法 说明
EmbInit(keys, values) 用提供的值初始化嵌入
EmbInit(keys, strategy) 根据策略 (Normal/Uniform/Xavier/Zero) 初始化
InitEmbeddingTable 创建新的嵌入表
SetPSConfig 动态设置 PS 连接配置 (Host/Port)

InitStrategy 支持

类型 参数
Normal mean, std
Uniform lower, upper
Xavier 自动计算
Zero 全零

同步读写

virtual void EmbRead(const RecTensor& keys, RecTensor& values) = 0;
virtual void EmbWrite(const RecTensor& keys, const RecTensor& values) = 0;
方法 说明 使用场景
EmbRead 批量读取嵌入向量 前向传播
EmbWrite 批量写入嵌入向量 初始化、检查点加载

参数

参数 类型 形状 说明
keys uint64 张量 [N] 嵌入 ID
values float32 张量 [N, D] 嵌入向量

异步预取

virtual uint64_t EmbPrefetch(const RecTensor& keys, const RecTensor& values) = 0;
virtual bool IsPrefetchDone(uint64_t prefetch_id) = 0;
virtual void WaitForPrefetch(uint64_t prefetch_id) = 0;
virtual void GetPretchResult(uint64_t prefetch_id, 
                             std::vector<std::vector<float>>* values) = 0;
方法 说明
EmbPrefetch 异步读取,返回唯一 ID
IsPrefetchDone 检查预取是否完成
WaitForPrefetch 阻塞等待预取完成
GetPretchResult 获取预取结果

工作流程

步骤 代码/操作 说明
1 pid = EmbPrefetch(keys, dummy_values) 发起异步预取
2 后台执行 不阻塞当前线程
3 while not IsPrefetchDone(pid): continue 轮询检查状态
4 WaitForPrefetch(pid) 阻塞等待完成
5 GetPretchResult(pid, &results) 获取预取结果
6 返回 vector<vector<float>> 返回向量数据

异步写入

virtual uint64_t EmbWriteAsync(const RecTensor& keys, 
                               const RecTensor& values) = 0;
virtual bool IsWriteDone(uint64_t write_id) = 0;
virtual void WaitForWrite(uint64_t write_id) = 0;
方法 说明
EmbWriteAsync 异步写入,返回写入 ID (未实现)
IsWriteDone 检查写入是否完成 (未实现)
WaitForWrite 阻塞等待写入完成 (未实现)

梯度更新

virtual void EmbUpdate(const RecTensor& keys, const RecTensor& grads) = 0;
virtual void EmbUpdate(const std::string& table_name, 
                       const RecTensor& keys, 
                       const RecTensor& grads) = 0;
方法 说明
EmbUpdate(keys, grads) 对所有表的指定 ID 更新梯度
EmbUpdate(table_name, keys, grads) 对特定表的指定 ID 更新梯度

梯度更新流程

步骤 代码位置 说明
1 src/framework/op.h:CommonOp 调用 EmbUpdate 将 (ID, grad) 送往 PS
2 src/ps/client/ PS 客户端与服务器通信
3 src/ps/server/ PS 内部执行 SGD: emb -= lr * grad
4 src/storage/ 更新后的嵌入向量存入存储层
5 后续读取 自动返回最新的嵌入向量

存在性检查与删除

virtual bool EmbExists(const RecTensor& keys) = 0;
virtual void EmbDelete(const RecTensor& keys) = 0;
方法 说明
EmbExists 检查 ID 是否存在
EmbDelete 删除指定 ID 的嵌入

持久化

virtual void SaveToFile(const std::string& path) = 0;
virtual void LoadFromFile(const std::string& path) = 0;
方法 说明
SaveToFile 将嵌入检查点保存到文件
LoadFromFile 从文件加载嵌入

KVClientOp 实现

KVClientOpCommonOp 的标准实现,通过 BasePSClient 与参数服务器通信。

成员变量

成员 类型 说明
embedding_dim_ int64_t 嵌入向量维度
ps_client_ BasePSClient* 参数服务器客户端

创建实例

std::shared_ptr<CommonOp> op = GetKVClientOp();

数据类型

RecTensor

用于跨 C++ 和 Python 传递数据的张量包装:

字段 说明
data_ptr 数据缓冲区指针
shape 张量形状
dtype 数据类型 (UINT64, FLOAT32)

InitStrategy

字段 说明
type InitStrategyType 枚举
mean, std Normal 初始化参数
lower, upper Uniform 初始化参数

EmbeddingTableConfig

定义嵌入表的配置(详见参数服务器文档)。

线程安全

KVClientOp 可在多线程环境下安全使用:

  • 每个线程可独立调用接口
  • 参数服务器客户端内部处理线程同步