[Paper] FedNI: Federated Graph Learning with Network Inpainting for Population-Based Disease Prediction
“FedNI: Federated Graph Learning with Network Inpainting for Population-Based Disease Prediction”이란 논문에 대한 리뷰입니다.
원문은 링크에서 확인할 수 있습니다.
Background
기존엔 GCN으로 Feature extraction진행 후 Classifier 진행 본 방식에선 GCN과 Classifier 모두 Federated 방식으로 학습 진행 -Optimization 방식이 Fed인 것이고 각각이 학습하는 건 task에 맞춰 정해진 NI
Part 1: Network Inpainting
: hidden node와 edge를 찾는 것
기존의 GCN은 graph에 대한 데이터가 complete하다는 가정을 하지만 실제론 incomplete하니 incomplete data 환경을 만들고 complete하게 예측하는 것을 학습해야 한다.
-
Hidden node를 정하기 BFS 방식으로 sub-tree를 만들고, leaf node를 hidden node를 정하게 된다.
-
Open node로 feature extraction 진행 GCN으로 open node의 embedding 생성 : Generator
중요한 것은 GCN에서 앞의 encoder만 사용하고 목적에 따라 following layer 정의 -
Hidden node 수 예측 open node의 embedding으로 hidden node의 개수를 예측하는 MLP
-
Hidden node feature 예측 Open node의 embedding으로 hidden node의 embedding을 예측하는 MLP
-
Hidden node feature의 다른 의미의 reconstruction loss 정의 기존의 reconstruction loss의 경우 Input vs decoder 차이를 줄이지만, Hidden node의 실제 embedding과 open node의 embedding으로 예측한 것(decoder의 결과)의 차이를 줄인다
- 그리고 예측된 hidden embedding의 분포와 실제 hidden embedding의 분포 차이를 줄임
SN-GAN의 Loss를 이용하기 위해 discriminator를 사용하고 결과적으로 discriminator를 만들고 이를 가지고 generator loss 정의
- 즉 Generator의 loss(얼마나 hidden feature를 예측하는가)는 GAN으로 학습된 discriminator 계산
- 최종적으로 Reconstruction loss와 Generator loss의 weighted sum
- 더 잘 예측하기 위해 다른 데이터(non-image data)를 사용 Phenotype data를 예측하는 모델 생성
Part 2: Federated Learning
NI as Fed
여기서 달라지는 것은 local과 global로 나뉘게 된다.
앞서 NI를 통해 local들의 Generator와 Discriminator를 학습하고, 이걸 Federate방식으로 aggregation step이 들어가는데 단 Discriminator는 Local한 상태로 둔다(NO FED)
이건 그냥 본인의 graph와 유사한지 안유사한지 구분하는 것이지 다른 graph의 information이 필요한 게 아니라 local graph의 heterogeneity를 알아야한다. 사실 이건 feature를 정확히 만드는 것이 중요한 건데 결국 이건 local에서 embedding을 어떻게 정교하게 만드는지를 다루는 것이니 global이 필요없게 되는 것이다.
Generator의 경우 centralized환경과 비교했을 때 다른 partition과 연결될 수 있는 node를 찾아야하기 때문에 global information이 필요하게 되어 Fed가 된다. 이건 궁극적으로 ML task를 할 때 사용되는 데이터이니 global distribution을 따르게 만드는 것.
- 이 과정을 통해 하나의 전체 graph를 예측하는 과정이 끝난다.
Classifier as Fed
이렇게 모인 것을 또 classifier를 학습할 때에도 Fed를 쓴다.
- 이건 server가 classifier를 하는 것이 아니라 local에서 각각 classifier를 학습하기 위한 것
- 즉, Fed는 결국 서버간에 정보공유 없이 일을 수행하되 align시키는 server-server schema
- 우리가 하는 것은 서버 아래에 client들간 align 시키는 server-client schema
Part 3: Application for our study
여기서 다루는 건 1) “서버급” edge들이 어떻게 align된 성능을 낼 것인지와 2) ML task에 맞춘 edge modeling
우리가 적용할 수 있는 부분이라면, NI가 끝나고 graph merge하는 단계에서 하나의 통일된 전체 graph를 만들 때 align시키는 것, 그리고 서버급 edge 아래에 client간 data를 align시키는 것
-
Generator를 FedAvg하는 것이 아니라 pseudo embedding으로 알아서 조정되게 만드는 방식 (개인적인 고민 1: 기존의 방식이라면 Reconstruction loss 뒤에 term을 추가해야 하는데, 이건 hidden node의 예측값을 align하는 것. 그렇다면 Open node부터 align되어야 할텐데 align한다면 그건 GCN을 학습할 때 해야하는 것인가? 즉 한번만 align시키면 될까?)
-
전체 분포를 사용하는 건 discriminator를 학습할 때도 사용되는데 이 때에도 align시키는 내용이 들어가야 하는가? 즉, (환자-병원) – (병원-환자) 급으로 진행하는 schema를 상상할 수 있지 않을까
댓글남기기