Using EMA with FSDP Models in PyTorch
Note: the method in this post only applies when the checkpoint sharding strategy is `SHARDED_STATE_DICT`.

How do you use EMA (an exponential moving average of the weights) once FSDP has sharded the model weights? I searched around online and could not find a single reliable approach, so I simply wrote one myself: keep an EMA copy of each local shard and update it in place with the usual rule `ema = decay * ema + (1 - decay) * param`. The implementation is below.

FSDP1 implementation (tested on torch 2.1):

```python
import os
from typing import Dict, List
from collections import defaultdict

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner


class ShardEMAModel:
    def __init__(self, fsdp_model: FSDP, decay: float = 0.999):
        assert isinstance(fsdp_model, FSDP)
        self.fsdp_model = fsdp_model
        self.decay = decay
        # clone every local shard so the EMA state owns its own storage
        self.shard_ema_state: Dict[str, List[torch.Tensor]] = defaultdict(list)
        shard_state = self._get_shard_state()
        for k, v in shard_state.items():
            # each value is a ShardedTensor; this rank only holds its local shards
            for local_shard in v._local_shards:
                self.shard_ema_state[k].append(local_shard.tensor.clone())
        self.num_shard_params = sum(sum(t.numel() for t in v) for v in self.shard_ema_state.values())
        print(f"Shard EMA Model has {self.num_shard_params / 1e6:.3f}M params.")

    def _get_shard_state(self):
        with FSDP.state_dict_type(self.fsdp_model, StateDictType.SHARDED_STATE_DICT):
            shard_state = self.fsdp_model.state_dict()
        return shard_state

    @torch.inference_mode()
    def update(self):
        """Update the EMA shard weights from the current model weights."""
        shard_state = self._get_shard_state()
        for k, v in shard_state.items():
            for idx, local_shard in enumerate(v._local_shards):
                self.shard_ema_state[k][idx].mul_(self.decay).add_(local_shard.tensor, alpha=1 - self.decay)

    def save_ema_shard_weights(self, save_dir: str):
        """Save the EMA shard weights."""
        with FSDP.state_dict_type(self.fsdp_model, StateDictType.SHARDED_STATE_DICT):
            os.makedirs(save_dir, exist_ok=True)
            shard_state = self.fsdp_model.state_dict()
            # swap the EMA tensors into the sharded state dict before writing
            for k, v in shard_state.items():
                for idx, local_shard in enumerate(v._local_shards):
                    local_shard.tensor = self.shard_ema_state[k][idx]
            state_dict = {"model": shard_state}
            dist_cp.save(
                state_dict=state_dict,
                storage_writer=dist_cp.FileSystemWriter(save_dir),
                planner=DefaultSavePlanner(),
            )

    def save_shard_weights(self, save_dir: str):
        """Save the original FSDP model shard weights."""
        with FSDP.state_dict_type(self.fsdp_model, StateDictType.SHARDED_STATE_DICT):
            os.makedirs(save_dir, exist_ok=True)
            shard_state = self.fsdp_model.state_dict()
            state_dict = {"model": shard_state}
            dist_cp.save(
                state_dict=state_dict,
                storage_writer=dist_cp.FileSystemWriter(save_dir),
                planner=DefaultSavePlanner(),
            )
```

FSDP2 implementation (requires torch >= 2.9.0; not yet tested):

```python
import os
from collections import OrderedDict

import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.tensor import DTensor
from torch.distributed.fsdp import FSDPModule


class ShardEMAModel:
    def __init__(self, fsdp_model: FSDPModule, decay: float = 0.999):
        assert isinstance(fsdp_model, FSDPModule)
        self.fsdp_model = fsdp_model
        self.decay = decay
        # clone so the EMA state owns its own storage: state_dict() returns
        # DTensor views of the live parameters, and updating those in place
        # would silently modify the model itself
        self.shard_ema_state: OrderedDict[str, DTensor] = OrderedDict(
            (k, v.detach().clone()) for k, v in self.fsdp_model.state_dict().items()
        )
        self.num_shard_params = sum(v.numel() for v in self.shard_ema_state.values())
        print(f"Shard EMA Model has {self.num_shard_params / 1e6:.3f}M params.")

    @torch.inference_mode()
    def update(self):
        """Update the EMA shard weights from the current model weights."""
        shard_state = self.fsdp_model.state_dict()
        for k, v in shard_state.items():
            self.shard_ema_state[k].mul_(self.decay).add_(v, alpha=1 - self.decay)

    def save_ema_shard_weights(self, save_dir: str):
        """Save the EMA shard weights."""
        os.makedirs(save_dir, exist_ok=True)
        state_dict = {"model": self.shard_ema_state}
        dcp.save(state_dict=state_dict, checkpoint_id=save_dir)

    def save_shard_weights(self, save_dir: str):
        """Save the original FSDP model shard weights."""
        os.makedirs(save_dir, exist_ok=True)
        shard_state = self.fsdp_model.state_dict()
        state_dict = {"model": shard_state}
        dcp.save(state_dict=state_dict, checkpoint_id=save_dir)
```
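Both classes only cover saving. To resume from or evaluate a saved EMA checkpoint, the sharded files can be read back with the matching checkpoint loader. Below is a minimal sketch for the FSDP1 case, assuming a torch version that exposes `dist_cp.load` (older releases call it `load_state_dict`); `ema_ckpt_dir` and the helper name are hypothetical:

```python
import torch.distributed.checkpoint as dist_cp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType


def load_ema_shard_weights(fsdp_model: FSDP, ema_ckpt_dir: str):
    """Load weights written by save_ema_shard_weights back into fsdp_model."""
    with FSDP.state_dict_type(fsdp_model, StateDictType.SHARDED_STATE_DICT):
        # allocate a sharded state dict with the same layout the checkpoint
        # was written in, then let the reader fill it in place
        state_dict = {"model": fsdp_model.state_dict()}
        dist_cp.load(
            state_dict=state_dict,
            storage_reader=dist_cp.FileSystemReader(ema_ckpt_dir),
        )
        fsdp_model.load_state_dict(state_dict["model"])
```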
Usage example:

```python
# create the FSDP model and its EMA counterpart
fsdp_model = FSDP(...)
ema_model = ShardEMAModel(fsdp_model, decay=0.99)

# train the fsdp model and update the optimizer weights
...

# update the EMA shard weights
ema_model.update()

# save the EMA shard weights
ema_model.save_ema_shard_weights(save_path)
```
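The checkpoints written this way stay in the sharded DCP layout on disk. If a single-file checkpoint is needed for inference on one device, newer torch releases ship a converter; a sketch, assuming torch >= 2.2 where `dcp_to_torch_save` exists, with hypothetical paths:

```python
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

# offline, single process: gather the sharded DCP checkpoint into a regular
# torch.save file; the result contains {"model": <full state dict>}
dcp_to_torch_save("checkpoints/ema_shard", "checkpoints/ema_full.pt")
```

The resulting file can then be read with `torch.load`, and its `"model"` entry passed to a plain, unsharded module's `load_state_dict`.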