
All in One: Multi-task Prompting for Graph Neural Networks, KDD 2023.





🌟ProG: A Unified Python Library for Graph Prompting

ProG (Prompt Graph) is a library built upon PyTorch to easily conduct single or multi-task prompting for pre-trained Graph Neural Networks (GNNs). The original idea is derived from the paper: Xiangguo Sun, Hong Cheng, Jia Li, et al. All in One: Multi-task Prompting for Graph Neural Networks. KDD 2023 (🔥 Best Research Paper Award, the first for Hong Kong and Mainland China; see the ori branch of this repository). Beyond All in One, this library supports more graph prompt models. Some of the integrated models are listed below (the list keeps growing):

  • [All in One] X. Sun, H. Cheng, J. Li, B. Liu, and J. Guan, “All in One: Multi-Task Prompting for Graph Neural Networks,” KDD, 2023.
  • [GPF Plus] T. Fang, Y. Zhang, Y. Yang, C. Wang, and L. Chen, “Universal Prompt Tuning for Graph Neural Networks,” NeurIPS, 2023.
  • [GraphPrompt] Z. Liu, X. Yu, Y. Fang, et al., “GraphPrompt: Unifying Pre-Training and Downstream Tasks for Graph Neural Networks,” The Web Conference, 2023.
  • [GPPT] M. Sun, K. Zhou, X. He, Y. Wang, and X. Wang, “GPPT: Graph Pre-Training and Prompt Tuning to Generalize Graph Neural Networks,” KDD, 2022.
  • [GPF] T. Fang, Y. Zhang, Y. Yang, and C. Wang, “Prompt Tuning for Graph Neural Networks,” arXiv preprint, 2022.


  • 2024/05/28: We are happy to announce that most of the update work for ProG is finished! (It lives in the main branch of this repository. If you are looking for the original ProG package, go to the ori branch.)
  • 2024/01/01: A major updated version was released!
  • 2023/11/28: We released a comprehensive survey on graph prompt!
  • 2023/11/15: We released a repository🦀 for a comprehensive collection of research papers, datasets, and readily accessible code implementations.

Quick Start

The Architecture of ProG is shown as follows:

We have provided scripts with hyper-parameter settings to reproduce the experimental results.

In the pre-training phase, you can obtain the experimental results by running the pre-training script with the parameters you want:

python --task Edgepred_Gprompt --dataset_name 'PubMed' --gnn_type 'GCN' --hid_dim 128 --num_layer 3 --epochs 50 --seed 42 --device 5

or run

cd scripts

For the downstream task, you can obtain the experimental results by running the downstream script with the parameters you want:

python --pre_train_path 'None' --task GraphTask --dataset_name 'MUTAG' --gnn_type 'GCN' --prompt_type 'None' --shot_num 10 --hid_dim 128 --num_layer 3 --epochs 50 --seed 42 --device 5

or run the corresponding script for the graph task on the MUTAG dataset, or the one for the node task on the Cora dataset.
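
The flags used in the two commands above can be parsed with the standard argparse module. The sketch below is only an assumption about how such a parser might look; it is not the library's actual get_args implementation. The function name build_parser is hypothetical, and the defaults are simply the values used in the example commands.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags in the example commands."""
    parser = argparse.ArgumentParser(description="ProG-style experiment flags (sketch)")
    parser.add_argument("--task", type=str, default="GraphTask")
    parser.add_argument("--dataset_name", type=str, default="MUTAG")
    parser.add_argument("--gnn_type", type=str, default="GCN")
    parser.add_argument("--prompt_type", type=str, default="None")
    parser.add_argument("--pre_train_path", type=str, default="None")
    parser.add_argument("--shot_num", type=int, default=10)
    parser.add_argument("--hid_dim", type=int, default=128)
    parser.add_argument("--num_layer", type=int, default=3)
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--device", type=int, default=5)
    return parser


# Parse part of the pre-training example, passed as a list instead of sys.argv.
args = build_parser().parse_args(
    ["--task", "Edgepred_Gprompt", "--dataset_name", "PubMed", "--hid_dim", "128"]
)
print(args.task, args.dataset_name, args.hid_dim)
```

Flags that are not given on the command line fall back to their defaults, so a short command still yields a complete configuration.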

Table of Contents
  1. Pre-train your GNN model
  2. Downstream Tasks
  3. Datasets
  4. Prompt Class
  5. Environment Setup
  6. TODO List

Pre-train your GNN model

The ProG.pretrain module provides several pre-training classes, including Edgepred_GPPT, Edgepred_Gprompt, GraphCL, and SimGRACE. You can pre-train a model by running the pre-training script with the parameters you want, or simply unzip the provided archive to get models that are already pre-trained on our datasets.

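import prompt_graph as ProG
# The package is installed as prompt_graph, so its submodules are imported under that name.
from prompt_graph.pretrain import Edgepred_GPPT, Edgepred_Gprompt, GraphCL, SimGRACE, NodePrePrompt, GraphPrePrompt, DGI, GraphMAE
from prompt_graph.utils import seed_everything
from prompt_graph.utils import mkdir, get_args
from prompt_graph.data import load4node, load4graph  # assumed module path for the data loaders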

args = get_args()

if args.task == 'SimGRACE':
    pt = SimGRACE(dataset_name = args.dataset_name, gnn_type = args.gnn_type, hid_dim = args.hid_dim, gln = args.num_layer, num_epoch=args.epochs, device=args.device)
if args.task == 'GraphCL':
    pt = GraphCL(dataset_name = args.dataset_name, gnn_type = args.gnn_type, hid_dim = args.hid_dim, gln = args.num_layer, num_epoch=args.epochs, device=args.device)
if args.task == 'Edgepred_GPPT':
    pt = Edgepred_GPPT(dataset_name = args.dataset_name, gnn_type = args.gnn_type, hid_dim = args.hid_dim, gln = args.num_layer, num_epoch=args.epochs, device=args.device)
if args.task == 'Edgepred_Gprompt':
    pt = Edgepred_Gprompt(dataset_name = args.dataset_name, gnn_type = args.gnn_type, hid_dim = args.hid_dim, gln = args.num_layer, num_epoch=args.epochs, device=args.device)
if args.task == 'DGI':
    pt = DGI(dataset_name = args.dataset_name, gnn_type = args.gnn_type, hid_dim = args.hid_dim, gln = args.num_layer, num_epoch=args.epochs, device=args.device)
if args.task == 'NodeMultiGprompt':
    nonlinearity = 'prelu'
    pt = NodePrePrompt(args.dataset_name, args.hid_dim, nonlinearity, 0.9, 0.9, 0.1, 0.001, 1, 0.3, args.device)
if args.task == 'GraphMultiGprompt':
    nonlinearity = 'prelu'
    pt = GraphPrePrompt(graph_list, input_dim, out_dim, args.dataset_name, args.hid_dim, nonlinearity,0.9,0.9,0.1,1,0.3, 0.1, args.device)
if args.task == 'GraphMAE':
    pt = GraphMAE(dataset_name = args.dataset_name, gnn_type = args.gnn_type, hid_dim = args.hid_dim, gln = args.num_layer, num_epoch=args.epochs, device=args.device,
                  mask_rate=0.75, drop_edge_rate=0.0, replace_rate=0.1, loss_fn='sce', alpha_l=2)
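
The chain of if statements above maps a task name to a pre-training class. One common alternative is a lookup table, which also fails loudly when the task name is misspelled. The sketch below illustrates that pattern with stand-in classes: DummySimGRACE, DummyGraphCL, and build_pretrainer are hypothetical placeholders, not part of ProG, so the example runs without the library installed.

```python
class DummySimGRACE:
    """Stand-in for a pre-training class; it only records its settings."""

    def __init__(self, dataset_name, gnn_type, hid_dim):
        self.settings = (dataset_name, gnn_type, hid_dim)


class DummyGraphCL(DummySimGRACE):
    """Second stand-in with the same constructor."""


# Map each task name to the class that handles it.
PRETRAINERS = {"SimGRACE": DummySimGRACE, "GraphCL": DummyGraphCL}


def build_pretrainer(task, **kwargs):
    """Look up the class registered for `task` and construct it."""
    try:
        cls = PRETRAINERS[task]
    except KeyError:
        known = sorted(PRETRAINERS)
        raise ValueError(f"Unknown task {task!r}; expected one of {known}") from None
    return cls(**kwargs)


pt = build_pretrainer("GraphCL", dataset_name="PubMed", gnn_type="GCN", hid_dim=128)
print(type(pt).__name__, pt.settings)
```

With the if chain, an unrecognized task leaves pt undefined and the error only surfaces later; the lookup raises a ValueError at the point of the mistake.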

Load Data

Before running a downstream task, we need to load the necessary data. For some prompt types, the output of the function load_induced_graph is also passed as input to the tasker.

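import os
import pickle

def load_induced_graph(dataset_name, data, device):
    # Induced graphs are cached on disk, one folder per dataset.
    folder_path = './Experiment/induced_graph/' + dataset_name
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)

    file_path = folder_path + '/induced_graph_min100_max300.pkl'
    if os.path.exists(file_path):
        # Cache hit: load the previously generated induced graphs.
        with open(file_path, 'rb') as f:
            print('loading induced graph...')
            graphs_list = pickle.load(f)
    else:
        # Cache miss: generate the induced graphs (expected to write file_path), then load them.
        print('Begin split_induced_graphs.')
        split_induced_graphs(data, folder_path, device, smallest_size=100, largest_size=300)
        with open(file_path, 'rb') as f:
            graphs_list = pickle.load(f)
    # Assumed: move every graph to the target device.
    graphs_list = [graph.to(device) for graph in graphs_list]
    return graphs_list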

args = get_args()

print('dataset_name', args.dataset_name)
if args.task == 'NodeTask':
    data, input_dim, output_dim = load4node(args.dataset_name)
    data = data.to(args.device)  # assumed: move the data to the target device
    if args.prompt_type in ['Gprompt', 'All-in-one', 'GPF', 'GPF-plus']:
        # These prompt types operate on induced subgraphs.
        graphs_list = load_induced_graph(args.dataset_name, data, args.device)
    else:
        graphs_list = None

if args.task == 'GraphTask':
    input_dim, output_dim, dataset = load4graph(args.dataset_name)
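
load_induced_graph follows a compute-once, load-afterwards pattern: it looks for a pickled file and only regenerates the induced graphs when that file is missing. The sketch below shows the pattern in isolation. The names load_or_build and expensive_build are hypothetical, and a plain nested list stands in for the induced graphs.

```python
import os
import pickle
import tempfile


def load_or_build(file_path, build_fn):
    """Return the cached object at `file_path`, building and saving it first if needed."""
    if os.path.exists(file_path):
        with open(file_path, "rb") as f:
            return pickle.load(f)
    # Cache miss: make sure the folder exists, build the object, then persist it.
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    result = build_fn()
    with open(file_path, "wb") as f:
        pickle.dump(result, f)
    return result


calls = []


def expensive_build():
    calls.append(1)          # count how often the expensive step runs
    return [[0, 1], [1, 2]]  # stand-in for a list of induced graphs


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "induced_graph", "demo.pkl")
    first = load_or_build(path, expensive_build)   # builds and writes the cache
    second = load_or_build(path, expensive_build)  # reads the cache

print(first == second, len(calls))
```

Only unpickle files you created yourself, since pickle.load can execute arbitrary code from an untrusted file.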

Downstream Tasks

We designed two downstream tasks (node classification and graph classification). Here are some examples.

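import prompt_graph as ProG
# The package is installed as prompt_graph, so its submodules are imported under that name.
from prompt_graph.tasker import NodeTask, LinkTask, GraphTask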

if args.task == 'GraphTask':
    input_dim, output_dim, dataset = load4graph(args.dataset_name)

# Note: args.lr and tasker.run() below are assumed names.
if args.task == 'NodeTask':
    tasker = NodeTask(pre_train_model_path = args.pre_train_model_path,
                    dataset_name = args.dataset_name, num_layer = args.num_layer,
                    gnn_type = args.gnn_type, hid_dim = args.hid_dim, prompt_type = args.prompt_type,
                    epochs = args.epochs, shot_num = args.shot_num, device=args.device, lr = args.lr, wd = args.decay,
                    batch_size = args.batch_size, data = data, input_dim = input_dim, output_dim = output_dim, graphs_list = graphs_list)

if args.task == 'GraphTask':
    tasker = GraphTask(pre_train_model_path = args.pre_train_model_path,
                    dataset_name = args.dataset_name, num_layer = args.num_layer, gnn_type = args.gnn_type, hid_dim = args.hid_dim, prompt_type = args.prompt_type, epochs = args.epochs,
                    shot_num = args.shot_num, device=args.device, lr = args.lr, wd = args.decay,
                    batch_size = args.batch_size, dataset = dataset, input_dim = input_dim, output_dim = output_dim)

# Run the task and collect the test metrics together with their standard deviations.
_, test_acc, std_test_acc, f1, std_f1, roc, std_roc, _, _ = tasker.run()

Kindly note that the comparison uses the same pre-trained .pth checkpoint. The absolute performance values do not mean much, because the final results may vary with the pre-training state. It is more informative to look at the relative performance across pre-training paradigms.
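
Since absolute numbers depend on the pre-training checkpoint, one way to read results is as a mean and standard deviation over several runs, compared relative to a baseline that shares the same checkpoint. The sketch below shows that calculation. The helper names summarize and relative_gain are hypothetical, and the accuracy values are placeholders for illustration, not results produced by ProG.

```python
from statistics import mean, stdev


def summarize(scores):
    """Return (mean, sample standard deviation) for a list of per-run scores."""
    return mean(scores), stdev(scores)


def relative_gain(method_scores, baseline_scores):
    """Relative improvement of the method's mean over the baseline's mean."""
    base = mean(baseline_scores)
    return (mean(method_scores) - base) / base


# Placeholder accuracies for five runs each; NOT measured results.
baseline = [0.50, 0.52, 0.48, 0.51, 0.49]
prompted = [0.55, 0.57, 0.53, 0.56, 0.54]

avg, std = summarize(prompted)
print(f"prompted: {avg:.3f} +/- {std:.3f}")
print(f"relative gain over baseline: {relative_gain(prompted, baseline):.1%}")
```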

Bench Random Search

In our bench


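| Graphs   | Graph classes | Avg. nodes | Avg. edges | Node features | Node classes | Task (N/E/G) |
|----------|---------------|------------|------------|---------------|--------------|--------------|
| Cora     | 1             | 2,708      | 5,429      | 1,433         | 7            | N            |
| Pubmed   | 1             | 19,717     | 88,648     | 500           | 3            | N            |
| CiteSeer | 1             | 3,327      | 9,104      | 3,703         | 6            | N            |
| Mutag    | 188           | 17.9       | 39.6       | ?             | 7            | N            |
| Reddit   | 1             | 232,965    | 23,213,838 | 602           | 41           | N            |
| Amazon   | 1             | 13,752     | 491,722    | 767           | 10           | N            |
| Flickr   | 1             | 89,250     | 899,756    | 500           | 7            | N            |
| PROTEINS | 1,113         | 39.06      | 72.82      | 1             | 3            | N, G         |
| ENZYMES  | 600           | 32.63      | 62.14      | 18            | 3            | N, G         |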
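
For multi-graph datasets, the "Avg. nodes" and "Avg. edges" columns are averages over all graphs in the dataset. A small sketch of that computation follows, assuming each graph is described by a node count and an edge list. The TinyGraph type and the toy graphs are hypothetical and do not reflect ProG's data format.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class TinyGraph:
    """Hypothetical minimal graph: a node count plus a list of edges."""
    num_nodes: int
    edges: List[Tuple[int, int]]


def dataset_stats(graphs):
    """Compute the number of graphs and the per-graph averages of nodes and edges."""
    n = len(graphs)
    avg_nodes = sum(g.num_nodes for g in graphs) / n
    avg_edges = sum(len(g.edges) for g in graphs) / n
    return {"graphs": n, "avg_nodes": avg_nodes, "avg_edges": avg_edges}


# Two toy path graphs with 3 and 5 nodes.
toy = [
    TinyGraph(3, [(0, 1), (1, 2)]),
    TinyGraph(5, [(0, 1), (1, 2), (2, 3), (3, 4)]),
]
stats = dataset_stats(toy)
print(stats)
```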

Environment Setup

--Python 3.9.17 

--PyTorch 2.0.1 

--torch-geometric 2.3.1

Installation for PyG (quick start):

pip install torch_geometric

pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f # Optional dependencies

or run this command

conda install pyg -c pyg
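
To confirm that an environment matches the versions listed above, the installed packages can be queried with the standard library. The sketch below assumes the distribution names torch and torch-geometric as published on PyPI; the helper check_versions is hypothetical. Installed builds may carry a suffix (for example a CUDA tag), so the output is printed for inspection rather than compared strictly.

```python
import sys
from importlib.metadata import PackageNotFoundError, version

# Versions listed in the environment setup above.
EXPECTED = {"torch": "2.0.1", "torch-geometric": "2.3.1"}


def check_versions(expected):
    """Return {package: (expected, installed or None)} for each listed package."""
    report = {}
    for name, wanted in expected.items():
        try:
            found = version(name)
        except PackageNotFoundError:
            found = None  # the package is not installed in this environment
        report[name] = (wanted, found)
    return report


print("Python", sys.version.split()[0], "(the setup above lists 3.9.17)")
for name, (wanted, found) in check_versions(EXPECTED).items():
    status = "missing" if found is None else found
    print(f"{name}: expected {wanted}, installed {status}")
```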


TODO List

Note: Current experimental datasets: Node/Edge: Cora/Citeseer/Pubmed; Graph: MUTAG

  • Write a comprehensive usage document (refer to PyG)
  • Write a tutorial and polish the data code so that readers can work with their own data more easily. That is: (1) provide a demo/tutorial showing how to handle data; (2) polish the data code to make it more robust, reliable, and readable.
  • Pre_train: implement DGI (Deep Graph Infomax), InfoGraph, ContextPred, AttrMasking, GraphMAE, GraphLoG, JOAO
  • Debug Gprompt inference, All-in-one tuning, and the GraphCL loss
  • Add Prompt: prodigy (NeurIPS'2023 Spotlight)
  • Induced graph: (1) find a better way to generate induced graphs; (2) simplify the three types of generation functions
  • Add a prompt type table (prompt_type, prompt paradigm, loss function, task_type)
  • Add a pre_train type table
  • Support deep GNN layers by adding the DeepGCNLayer feature

🌹Please Cite Our Work If Helpful:

Thanks! / 谢谢! / ありがとう! / merci! / 감사! / Danke! / спасибо! / gracias! ...

  title={All in One: Multi-Task Prompting for Graph Neural Networks},
  author={Sun, Xiangguo and Cheng, Hong and Li, Jia and Liu, Bo and Guan, Jihong},
  booktitle={Proceedings of the 26th ACM SIGKDD international conference on knowledge discovery \& data mining (KDD'23)},
  pages = {2120–2131},
  location = {Long Beach, CA, USA},
  isbn = {9798400701030},
  url = {},
  doi = {10.1145/3580305.3599256}

  title = {Graph Prompt Learning: A Comprehensive Survey and Beyond},
  author = {Sun, Xiangguo and Zhang, Jiawen and Wu, Xixi and Cheng, Hong and Xiong, Yun and Li, Jia},
  year = {2023},
  journal = {arXiv:2311.16534},
  eprint = {2311.16534},
  archiveprefix = {arxiv}

      title={All in One and One for All: A Simple yet Effective Method towards Cross-domain Graph Pretraining}, 
      author={Haihong Zhao and Aochuan Chen and Xiangguo Sun and Hong Cheng and Jia Li},
      booktitle={Proceedings of the 27th ACM SIGKDD international conference on knowledge discovery \& data mining (KDD'24)}

  title={Protein Multimer Structure Prediction via {PPI}-guided Prompt Learning},
  author={Ziqi Gao and Xiangguo Sun and Zijing Liu and Yu Li and Hong Cheng and Jia Li},
  booktitle={The Twelfth International Conference on Learning Representations (ICLR)},

      title={Prompt Learning on Temporal Interaction Graphs}, 
      author={Xi Chen and Siwei Zhang and Yun Xiong and Xixi Wu and Jiawei Zhang and Xiangguo Sun and Yao Zhang and Yinglong Zhao and Yulin Kang},
      journal = {arXiv:2402.06326}

      title={A Survey of Graph Meets Large Language Model: Progress and Future Directions}, 
      author={Yuhan Li and Zhixun Li and Peisong Wang and Jia Li and Xiangguo Sun and Hong Cheng and Jeffrey Xu Yu},
      journal = {arXiv:2311.12399}

  title={Advanced Drug Interaction Event Prediction},
  author={Wang, Yingying and Xiong, Yun and Wu, Xixi and Sun, Xiangguo and Zhang, Jiawei},
  journal={arXiv preprint arXiv:2402.11472},

Call for Contributors!

Once you are invited as a contributor, you will be asked to follow these steps:

  • Step 1. Create a temp branch (e.g. xgTemp) from the main branch (the latest branch).
  • Step 2. Fetch origin/xgTemp into your local xgTemp branch and make your changes (for example, in PyCharm).
  • Step 3. Push your changes from your local xgTemp to your GitHub branch origin/xgTemp.
  • Step 4. Open a pull request to merge your branch into main.

When you have finished all these steps, I will get a notification and approve merging your branch into main. Once that is done, I will delete your branch, and next time you will repeat the steps above.

A widely tested main branch will then be merged into the stable branch, and a new version will be released based on the stable branch.