
【OSCP】 Add fed_pac strategy based on the torch backend to SecretFlow #1276

Open
wants to merge 10 commits into base: main

@Dorothy21650 commented Apr 28, 2024

Type of change

  • Add new papers (Please tell us why you think this paper is awesome!)
  • Fix the category of an existing paper/papers (Please tell us the reasons)
  • [x] Add a new tool/primitive/application with a new markdown page (Thank you! Also, please tell us more about this awesome thing!)

Description

Fixed #1280


github-actions bot commented Apr 28, 2024

CLA Assistant Lite bot All contributors have signed the CLA ✍️ ✅

@Dorothy21650 (Author)

I have read the CLA Document and I hereby sign the CLA

@wyanbing

I have read the CLA Document and I hereby sign the CLA

@zhaocaibei123 (Contributor) left a review comment:

Please add a unit test, using tests/ml/nn/fl/test_fl_model_torch.py as a reference.

secretflow/ml/nn/core/torch/module.py (outdated)
from torch.utils.data import DataLoader


def cifar10(stage='train'):
Contributor:

Move the dataset processing into the unit test or a test script.

Collaborator:

+1. Move it into the unit test and use the custom databuilder pattern, or reuse the logic in secretflow/utils/simulation/datasets.py directly.
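The suggestion to move dataset preparation into the test can be sketched as a builder factory. This is a minimal, stdlib-only sketch under assumptions: `create_dataset_builder` is a hypothetical name, and it returns plain lists of `(features, labels)` batches where a real builder would return a torch `DataLoader`. SecretFlow's actual builder interface is not shown in this thread.

```python
import random
from typing import Callable, List, Sequence, Tuple

Batch = Tuple[List[float], List[int]]


def create_dataset_builder(
    batch_size: int = 32, shuffle: bool = True, seed: int = 0
) -> Callable[[Sequence[float], Sequence[int]], List[Batch]]:
    """Hypothetical factory: captures loader settings, returns a builder.

    The builder runs later on the party that owns the data, so the test
    file, not library code, decides how batches are formed.
    """

    def dataset_builder(x: Sequence[float], y: Sequence[int]) -> List[Batch]:
        indices = list(range(len(x)))
        if shuffle:
            random.Random(seed).shuffle(indices)  # seeded for reproducible tests
        batches = []
        for start in range(0, len(indices), batch_size):
            chunk = indices[start:start + batch_size]
            batches.append(([x[i] for i in chunk], [y[i] for i in chunk]))
        return batches

    return dataset_builder


# Usage inside a unit test: build batches for one simulated party.
builder = create_dataset_builder(batch_size=4, shuffle=False)
batches = builder([0.1 * i for i in range(10)], [i % 2 for i in range(10)])
print(len(batches), [len(b[0]) for b in batches])  # 3 [4, 4, 2]
```

Keeping the factory in the test file means library modules such as `module.py` never need a dataset-specific helper like `cifar10`.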

@@ -49,6 +51,7 @@ def __init__(
         self.train_set = None
         self.eval_set = None
         self.skip_bn = skip_bn
+        # self.dataset_size = torch.tensor(len(self.train_set.dataset)).to(self.exe_device)
Contributor:

Please delete the unused code.

Author:

Done.

@@ -214,11 +217,11 @@ def next_batch(self, stage="train"):
     def get_rows_count(self, filename):
         return int(rows_count(filename=filename)) - 1 # except header line
 
-    def get_weights(self, return_numpy=True):
+    def get_weights(self):
Contributor:

Why was the return_numpy parameter removed here?

secretflow/ml/nn/fl/backend/torch/fl_base.py (outdated)
self.num_classes = kwargs.get("num_classes", 10)
self.criterion = nn.CrossEntropyLoss()
self.local_model = self.model
self.last_model = deepcopy(self.model)
Contributor:

What is self.last_model used for?

Author:

This line was copied as-is from the original source code, which saves a deep copy of the model before local training starts, probably to allow a rollback. However, the original code never actually uses self.last_model, so I have removed this line.
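The author's reply turns on the difference between aliasing a model and snapshotting it. A minimal sketch, using a hypothetical `TinyModel` with plain-float parameters in place of an `nn.Module`: `local_model = model` is only a second name for the same object, while `deepcopy` gives an independent pre-training copy that is worth keeping only if something later reads it, for example a rollback.

```python
from copy import deepcopy
from typing import List


class TinyModel:
    """Hypothetical stand-in for an nn.Module: a flat list of parameters."""

    def __init__(self, params: List[float]):
        self.params = list(params)

    def train_step(self, grads: List[float], lr: float = 0.5) -> None:
        # In-place update, the way optimizer.step() mutates module parameters.
        for i, g in enumerate(grads):
            self.params[i] -= lr * g


model = TinyModel([1.0, 2.0])
local_model = model           # alias: same object, nothing is copied
last_model = deepcopy(model)  # independent snapshot taken before training

local_model.train_step([1.0, 1.0])
after_training = list(model.params)
print(after_training)     # [0.5, 1.5]  (changed through the alias)
print(last_model.params)  # [1.0, 2.0]  (snapshot untouched)

# Rollback: restore the pre-training parameters from the snapshot.
model.params = deepcopy(last_model.params)
```

Since the snapshot costs a full copy of the model's parameters, dropping it when nothing reads it, as done in this PR, avoids that overhead.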

from copy import deepcopy


class FedPACTorchModel(BaseTorchModel):
Contributor:

This class could be merged into FedPAC in strategy/fed_pac.py.

Author:

Done.

lr = kwargs.get("lr", 0.01)

epoch_classifier = 1
optimizer = torch.optim.SGD(
Contributor:

Could the optimizer, loss, and similar settings be defined in TorchModel instead?
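One way to read the reviewer's suggestion: the strategy receives optimizer and loss factories from the model definition instead of constructing `torch.optim.SGD` inline. The sketch below is stdlib-only and hypothetical. `ModelSpec` stands in for `TorchModel`, and `ToySGD` re-implements PyTorch's SGD update rule on a list of floats (weight decay added to the gradient, momentum buffer initialised to the first step's direction, no dampening or Nesterov) so that the `lr`, `momentum`, and `weight_decay` values in this hunk have something to act on.

```python
from dataclasses import dataclass
from functools import partial
from typing import Callable, List, Optional


class ToySGD:
    """Stdlib stand-in for torch.optim.SGD (dampening=0, nesterov=False).

    Per step: d = g + weight_decay * p; buf = d on the first step,
    else buf = momentum * buf + d; then p -= lr * buf.
    """

    def __init__(self, params: List[float], lr: float,
                 momentum: float = 0.0, weight_decay: float = 0.0):
        self.params = params  # updated in place, like module parameters
        self.lr, self.momentum, self.weight_decay = lr, momentum, weight_decay
        self.buf: Optional[List[float]] = None

    def step(self, grads: List[float]) -> None:
        d = [g + self.weight_decay * p for g, p in zip(grads, self.params)]
        if self.momentum:
            if self.buf is None:
                self.buf = d
            else:
                self.buf = [self.momentum * b + x for b, x in zip(self.buf, d)]
            d = self.buf
        for i, x in enumerate(d):
            self.params[i] -= self.lr * x


@dataclass
class ModelSpec:
    """Hypothetical model definition that owns optimizer and loss factories."""
    optim_fn: Callable[[List[float]], ToySGD]
    loss_fn: Callable[[float, float], float]


def squared_error(pred: float, target: float) -> float:
    return (pred - target) ** 2


spec = ModelSpec(
    optim_fn=partial(ToySGD, lr=0.01, momentum=0.5, weight_decay=0.0005),
    loss_fn=squared_error,
)

# The strategy only calls the factories; it no longer hard-codes SGD.
params = [1.0]
opt = spec.optim_fn(params)
opt.step([2.0])  # p = 1.0 - 0.01 * (2.0 + 0.0005 * 1.0)
```

Binding hyperparameters with `functools.partial` keeps them next to the model definition, so swapping optimizers touches only the spec.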

momentum=0.5,
weight_decay=0.0005,
)
for ep in range(epoch_classifier):
Contributor:

Does each step train for a full epoch?
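The question matters because of how the loops nest. A minimal sketch, assuming (hypothetically, since the thread does not say) that the framework calls the strategy's train step a fixed number of times per round and expects each call to consume one batch. If each call instead loops over the whole loader `epoch_classifier` times, the work per round grows by a factor of the loader length. Both function names are illustrative.

```python
from itertools import cycle, islice
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def batches_if_epoch_per_step(num_steps: int, epoch_classifier: int, num_batches: int) -> int:
    # Each train-step call iterates the full loader epoch_classifier times.
    return num_steps * epoch_classifier * num_batches


def batches_if_batch_per_step(loader: Sequence[T], num_steps: int) -> List[T]:
    # Each train-step call consumes exactly one batch, wrapping around the loader.
    return list(islice(cycle(loader), num_steps))


# Hypothetical round: 10 train-step calls over a loader of 100 batches.
print(batches_if_epoch_per_step(10, 1, 100))                 # 1000
print(len(batches_if_batch_per_step(list(range(100)), 10)))  # 10
```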

from secretflow.security.aggregation.aggregator import Aggregator


class FedPACAggregator(Aggregator):
Contributor:

If the average and sum logic is unchanged, this class can inherit from PlainAggregator.
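The reviewer's point is ordinary subclass reuse. A minimal sketch with hypothetical stand-in classes (not SecretFlow's `PlainAggregator`, whose exact signatures are not shown here): the subclass inherits `sum` and `average` unchanged and adds only strategy-specific logic. The extra method, a sample-count-weighted average of client feature centroids for one class, is an assumed example of what FedPAC might need.

```python
from typing import List, Optional, Sequence


class PlainAggregatorSketch:
    """Hypothetical stand-in for a plain (unencrypted) aggregator."""

    def sum(self, data: Sequence[Sequence[float]]) -> List[float]:
        # Element-wise sum across parties.
        return [sum(col) for col in zip(*data)]

    def average(self, data: Sequence[Sequence[float]],
                weights: Optional[Sequence[float]] = None) -> List[float]:
        # Element-wise weighted mean across parties (equal weights by default).
        if weights is None:
            weights = [1.0] * len(data)
        total = sum(weights)
        return [sum(w * x for w, x in zip(weights, col)) / total for col in zip(*data)]


class FedPACAggregatorSketch(PlainAggregatorSketch):
    """Reuses sum/average as-is; only FedPAC-specific logic lives here."""

    def aggregate_centroids(self, centroids: Sequence[Sequence[float]],
                            counts: Sequence[int]) -> List[float]:
        # Assumed example: weight each client's centroid by its sample count.
        return self.average(centroids, weights=counts)


agg = FedPACAggregatorSketch()
print(agg.aggregate_centroids([[0.0, 0.0], [4.0, 8.0]], counts=[1, 3]))  # [3.0, 6.0]
```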

@Dorothy21650 changed the title from "add fed_pac strategy based on torch backend" to "【OSCP】 Add fed_pac strategy based on the torch backend to SecretFlow" on May 10, 2024
@@ -216,9 +219,9 @@ def get_rows_count(self, filename):
 
     def get_weights(self, return_numpy=True):
         if self.skip_bn:
-            return self.model.get_weights_not_bn(return_numpy=return_numpy)
+            return self.model.get_weights_not_bn(return_numpy=True)
Contributor:

Why is return_numpy hardcoded here?
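Both `return_numpy` comments point at the same fix: keep the parameter and forward it. A minimal sketch with hypothetical stand-ins: `StubModule` plays the wrapped module, tuples play numpy arrays, and lists play tensors, since the real conversion is not shown in the thread. The non-`skip_bn` branch calling `get_weights` is also an assumption.

```python
from typing import Dict, List


class StubModule:
    """Hypothetical wrapped module. Tuples stand in for numpy arrays,
    lists for tensors."""

    def __init__(self) -> None:
        self.weights: Dict[str, List[float]] = {
            "conv.weight": [0.5, -0.5],
            "bn.running_mean": [0.0],
        }

    def _export(self, names: List[str], return_numpy: bool) -> list:
        return [tuple(self.weights[n]) if return_numpy else list(self.weights[n])
                for n in names]

    def get_weights(self, return_numpy: bool = False) -> list:
        return self._export(list(self.weights), return_numpy)

    def get_weights_not_bn(self, return_numpy: bool = False) -> list:
        names = [n for n in self.weights if not n.startswith("bn.")]
        return self._export(names, return_numpy)


class ModelWrapper:
    def __init__(self, model: StubModule, skip_bn: bool) -> None:
        self.model, self.skip_bn = model, skip_bn

    def get_weights(self, return_numpy: bool = True) -> list:
        # Forward the caller's choice on both branches instead of hard-coding True.
        if self.skip_bn:
            return self.model.get_weights_not_bn(return_numpy=return_numpy)
        return self.model.get_weights(return_numpy=return_numpy)


wrapper = ModelWrapper(StubModule(), skip_bn=True)
print(wrapper.get_weights())                    # [(0.5, -0.5)]
print(wrapper.get_weights(return_numpy=False))  # [[0.5, -0.5]]
```

Forwarding keeps the default behaviour (`return_numpy=True`) for existing callers while still letting a strategy request tensors.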

@@ -23,7 +23,6 @@

from .mixins import ParametersMixin


Collaborator:

Please fix your formatter configuration; this part should not be modified.

return dataset


def batch_sampler(
Collaborator:

batch_sampler is already implemented in secretflow/ml/nn/fl/backend/torch/sampler.py; just call it directly.


Labels: None yet
Projects: None yet
Development: Successfully merging this pull request may close these issues:
  • Add fed_pac strategy based on the torch backend to SecretFlow
5 participants