torch_geometric.loader — pytorch_geometric documentation

DataLoader: A data loader which merges data objects from a torch_geometric.data.Dataset to a mini-batch.
NodeLoader: A data loader that performs mini-batch sampling from node information, using a generic BaseSampler implementation that defines a sample_from_nodes() function and is supported on the provided input data object.
LinkLoader: A data loader that performs mini-batch sampling from link information, using a generic BaseSampler implementation that defines a sample_from_edges() function and is supported on the provided input data object.
NeighborLoader: A data loader that performs neighbor sampling as introduced in the "Inductive Representation Learning on Large Graphs" paper.
LinkNeighborLoader: A link-based data loader derived as an extension of the node-based torch_geometric.loader.NeighborLoader.
HGTLoader: The Heterogeneous Graph Sampler from the "Heterogeneous Graph Transformer" paper.
ClusterData: Clusters/partitions a graph data object into multiple subgraphs, as motivated by the "Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks" paper.
ClusterLoader: The data loader scheme from the "Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks" paper which merges partitioned subgraphs and their between-cluster links from a large-scale graph data object to form a mini-batch.
GraphSAINTSampler: The GraphSAINT sampler base class from the "GraphSAINT: Graph Sampling Based Inductive Learning Method" paper.
GraphSAINTNodeSampler: The GraphSAINT node sampler class (see GraphSAINTSampler).
GraphSAINTEdgeSampler: The GraphSAINT edge sampler class (see GraphSAINTSampler).
GraphSAINTRandomWalkSampler: The GraphSAINT random walk sampler class (see GraphSAINTSampler).
ShaDowKHopSampler: The ShaDow \(k\)-hop sampler from the "Decoupling the Depth and Scope of Graph Neural Networks" paper.
RandomNodeLoader: A data loader that randomly samples nodes within a graph and returns their induced subgraph.
ZipLoader: A loader that returns a tuple of data objects by sampling from multiple NodeLoader or LinkLoader instances.
DataListLoader: A data loader which batches data objects from a torch_geometric.data.Dataset to a Python list.
DenseDataLoader: A data loader which batches data objects from a torch_geometric.data.Dataset to a torch_geometric.data.Batch object by stacking all attributes in a new dimension.
TemporalDataLoader: A data loader which merges successive events of a torch_geometric.data.TemporalData to a mini-batch.
NeighborSampler: The neighbor sampler from the "Inductive Representation Learning on Large Graphs" paper, which allows for mini-batch training of GNNs on large-scale graphs where full-batch training is not feasible.
ImbalancedSampler: A weighted random sampler that randomly samples elements according to class distribution.
DynamicBatchSampler: Dynamically adds samples to a mini-batch up to a maximum size (either based on number of nodes or number of edges).
PrefetchLoader: A GPU prefetcher class for asynchronously transferring data of a torch.utils.data.DataLoader from host memory to device memory.
CachedLoader: A loader to cache mini-batch outputs, e.g., obtained during NeighborLoader iterations.
AffinityMixin: A context manager to enable CPU affinity for data loader workers (only used when running on CPU devices).
RAGQueryLoader: A loader for making RAG queries from a remote backend.
RAGFeatureStore: Feature store template for a remote GNN RAG backend.
RAGGraphStore: Graph store template for a remote GNN RAG backend.

class DataLoader(dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter], batch_size: int = 1, shuffle: bool = False, follow_batch: Optional[List[str]] = None, exclude_keys: Optional[List[str]] = None, **kwargs)[source]

A data loader which merges data objects from a torch_geometric.data.Dataset to a mini-batch. Data objects can be either of type Data or HeteroData.
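For illustration, a minimal usage sketch (the TUDataset and its root path are arbitrary choices, not part of the API):

import torch_geometric.datasets as datasets
from torch_geometric.loader import DataLoader

# Any graph-level dataset works here; TUDataset is just an example:
dataset = datasets.TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for batch in loader:
    # `batch` is a `torch_geometric.data.Batch` holding a disjoint union of
    # up to 32 graphs; `batch.batch` maps each node to its source graph.
    print(batch.num_graphs)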

Parameters:

class NodeLoader(data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]], node_sampler: BaseSampler, input_nodes: Union[Tensor, None, str, Tuple[str, Optional[Tensor]]] = None, input_time: Optional[Tensor] = None, transform: Optional[Callable] = None, transform_sampler_output: Optional[Callable] = None, filter_per_worker: Optional[bool] = None, custom_cls: Optional[HeteroData] = None, input_id: Optional[Tensor] = None, **kwargs)[source]

A data loader that performs mini-batch sampling from node information, using a generic BaseSampler implementation that defines a sample_from_nodes() function and is supported on the provided input data object.

Parameters:

collate_fn(index: Union[Tensor, List[int]]) → Any[source]

Samples a subgraph from a batch of input nodes.

Return type:

Any

filter_fn(out: Union[SamplerOutput, HeteroSamplerOutput]) → Union[Data, HeteroData][source]

Joins the sampled nodes with their corresponding features, returning the resulting Data or HeteroData object to be used downstream.

Return type:

Union[Data, HeteroData]
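As an illustrative sketch, NodeLoader can be driven by the generic torch_geometric.sampler.NeighborSampler (any BaseSampler implementing sample_from_nodes() would do; data is assumed to be a Data object with a train_mask):

from torch_geometric.loader import NodeLoader
from torch_geometric.sampler import NeighborSampler

# Any `BaseSampler` defining `sample_from_nodes()` can be plugged in here:
sampler = NeighborSampler(data, num_neighbors=[10, 10])
loader = NodeLoader(data, node_sampler=sampler,
                    input_nodes=data.train_mask, batch_size=128)

batch = next(iter(loader))  # A `Data` object holding the sampled subgraph.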

class LinkLoader(data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]], link_sampler: BaseSampler, edge_label_index: Union[Tensor, None, Tuple[str, str, str], Tuple[Tuple[str, str, str], Optional[Tensor]]] = None, edge_label: Optional[Tensor] = None, edge_label_time: Optional[Tensor] = None, neg_sampling: Optional[NegativeSampling] = None, neg_sampling_ratio: Optional[Union[int, float]] = None, transform: Optional[Callable] = None, transform_sampler_output: Optional[Callable] = None, filter_per_worker: Optional[bool] = None, custom_cls: Optional[HeteroData] = None, input_id: Optional[Tensor] = None, **kwargs)[source]

A data loader that performs mini-batch sampling from link information, using a generic BaseSampler implementation that defines a sample_from_edges() function and is supported on the provided input data object.

Note

Negative sampling is currently implemented in an approximate way, i.e. negative edges may contain false negatives.

Parameters:

collate_fn(index: Union[Tensor, List[int]]) → Any[source]

Samples a subgraph from a batch of input edges.

Return type:

Any

filter_fn(out: Union[SamplerOutput, HeteroSamplerOutput]) → Union[Data, HeteroData][source]

Joins the sampled nodes with their corresponding features, returning the resulting Data or HeteroData object to be used downstream.

Return type:

Union[Data, HeteroData]

class NeighborLoader(data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]], num_neighbors: Union[List[int], Dict[Tuple[str, str, str], List[int]]], input_nodes: Union[Tensor, None, str, Tuple[str, Optional[Tensor]]] = None, input_time: Optional[Tensor] = None, replace: bool = False, subgraph_type: Union[SubgraphType, str] = 'directional', disjoint: bool = False, temporal_strategy: str = 'uniform', time_attr: Optional[str] = None, weight_attr: Optional[str] = None, transform: Optional[Callable] = None, transform_sampler_output: Optional[Callable] = None, is_sorted: bool = False, filter_per_worker: Optional[bool] = None, neighbor_sampler: Optional[NeighborSampler] = None, directed: bool = True, **kwargs)[source]

A data loader that performs neighbor sampling as introduced in the “Inductive Representation Learning on Large Graphs” paper. This loader allows for mini-batch training of GNNs on large-scale graphs where full-batch training is not feasible.

More specifically, num_neighbors denotes how many neighbors are sampled for each node in each iteration. NeighborLoader takes in this list of num_neighbors and iteratively samples num_neighbors[i] for each node involved in iteration i - 1.

Sampled nodes are sorted based on the order in which they were sampled. In particular, the first batch_size nodes represent the set of original mini-batch nodes.

from torch_geometric.datasets import Planetoid
from torch_geometric.loader import NeighborLoader

data = Planetoid(path, name='Cora')[0]

loader = NeighborLoader(
    data,
    # Sample 30 neighbors for each node for 2 iterations:
    num_neighbors=[30] * 2,
    # Use a batch size of 128 for sampling training nodes:
    batch_size=128,
    input_nodes=data.train_mask,
)

sampled_data = next(iter(loader))
print(sampled_data.batch_size)
128

By default, the data loader will only include the edges that were originally sampled (directed = True). This option should only be used when the number of hops is equal to the number of GNN layers. If the number of GNN layers is greater than the number of hops, consider setting directed = False, which will include all edges between all sampled nodes (but is slightly slower as a result).

Furthermore, NeighborLoader works for both homogeneous graphs stored via Data as well as heterogeneous graphs stored via HeteroData. When operating on heterogeneous graphs, up to num_neighbors neighbors will be sampled for each edge_type. However, more fine-grained control over the amount of sampled neighbors of individual edge types is possible:

from torch_geometric.datasets import OGB_MAG
from torch_geometric.loader import NeighborLoader

hetero_data = OGB_MAG(path)[0]

loader = NeighborLoader(
    hetero_data,
    # Sample 30 neighbors for each node and edge type for 2 iterations:
    num_neighbors={key: [30] * 2 for key in hetero_data.edge_types},
    # Use a batch size of 128 for sampling training nodes of type "paper":
    batch_size=128,
    input_nodes=('paper', hetero_data['paper'].train_mask),
)

sampled_hetero_data = next(iter(loader))
print(sampled_hetero_data['paper'].batch_size)
128

The NeighborLoader will return subgraphs where global node indices are mapped to local indices corresponding to this specific subgraph. However, it is often desirable to map the nodes of the current subgraph back to the global node indices. The NeighborLoader will include this mapping as part of the data object:

loader = NeighborLoader(data, ...)
sampled_data = next(iter(loader))
print(sampled_data.n_id)  # Global node index of each node in the batch.

In particular, the data loader will add the following attributes to the returned mini-batch:

Parameters:

class LinkNeighborLoader(data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]], num_neighbors: Union[List[int], Dict[Tuple[str, str, str], List[int]]], edge_label_index: Union[Tensor, None, Tuple[str, str, str], Tuple[Tuple[str, str, str], Optional[Tensor]]] = None, edge_label: Optional[Tensor] = None, edge_label_time: Optional[Tensor] = None, replace: bool = False, subgraph_type: Union[SubgraphType, str] = 'directional', disjoint: bool = False, temporal_strategy: str = 'uniform', neg_sampling: Optional[NegativeSampling] = None, neg_sampling_ratio: Optional[Union[int, float]] = None, time_attr: Optional[str] = None, weight_attr: Optional[str] = None, transform: Optional[Callable] = None, transform_sampler_output: Optional[Callable] = None, is_sorted: bool = False, filter_per_worker: Optional[bool] = None, neighbor_sampler: Optional[NeighborSampler] = None, directed: bool = True, **kwargs)[source]

A link-based data loader derived as an extension of the node-based torch_geometric.loader.NeighborLoader. This loader allows for mini-batch training of GNNs on large-scale graphs where full-batch training is not feasible.

More specifically, this loader first selects a sample of edges from the set of input edges edge_label_index (which may or may not be edges in the original graph) and then constructs a subgraph from all the nodes present in this list by sampling num_neighbors neighbors in each iteration.

from torch_geometric.datasets import Planetoid
from torch_geometric.loader import LinkNeighborLoader

data = Planetoid(path, name='Cora')[0]

loader = LinkNeighborLoader(
    data,
    # Sample 30 neighbors for each node for 2 iterations:
    num_neighbors=[30] * 2,
    # Use a batch size of 128 for sampling training edges:
    batch_size=128,
    edge_label_index=data.edge_index,
)

sampled_data = next(iter(loader))
print(sampled_data)

Data(x=[1368, 1433], edge_index=[2, 3103], y=[1368], train_mask=[1368], val_mask=[1368], test_mask=[1368], edge_label_index=[2, 128])

It is additionally possible to provide edge labels for sampled edges, which are then added to the batch:

import torch

loader = LinkNeighborLoader(
    data,
    num_neighbors=[30] * 2,
    batch_size=128,
    edge_label_index=data.edge_index,
    edge_label=torch.ones(data.edge_index.size(1)),
)

sampled_data = next(iter(loader))
print(sampled_data)

Data(x=[1368, 1433], edge_index=[2, 3103], y=[1368], train_mask=[1368], val_mask=[1368], test_mask=[1368], edge_label_index=[2, 128], edge_label=[128])

The rest of the functionality mirrors that of NeighborLoader, including support for heterogeneous graphs. In particular, the data loader will add the following attributes to the returned mini-batch:

Note

Negative sampling is currently implemented in an approximate way, i.e. negative edges may contain false negatives.

Warning

Note that the sampling scheme is independent of the edge we are making a prediction for. That is, by default, supervision edges in edge_label_index will not get masked out during sampling. In case there exists an overlap between message passing edges in data.edge_index and supervision edges in edge_label_index, you might end up sampling an edge you are making a prediction for. You can generally avoid this behavior (if desired) by making data.edge_index and edge_label_index two disjoint sets of edges, e.g., via the RandomLinkSplit transformation and its disjoint_train_ratio argument.

Parameters:

class HGTLoader(data: Union[HeteroData, Tuple[FeatureStore, GraphStore]], num_samples: Union[List[int], Dict[str, List[int]]], input_nodes: Union[str, Tuple[str, Optional[Tensor]]], is_sorted: bool = False, transform: Optional[Callable] = None, transform_sampler_output: Optional[Callable] = None, filter_per_worker: Optional[bool] = None, **kwargs)[source]

The Heterogeneous Graph Sampler from the “Heterogeneous Graph Transformer” paper. This loader allows for mini-batch training of GNNs on large-scale graphs where full-batch training is not feasible.

HGTLoader tries to (1) keep a similar number of nodes and edges for each type and (2) keep the sampled sub-graph dense to minimize the information loss and reduce the sample variance.

Methodically, HGTLoader keeps track of a node budget for each node type, which is then used to determine the sampling probability of a node. In particular, the probability of sampling a node is determined by the number of connections to already sampled nodes and their node degrees. With this, HGTLoader will sample a fixed amount of neighbors for each node type in each iteration, as given by the num_samples argument.

Sampled nodes are sorted based on the order in which they were sampled. In particular, the first batch_size nodes represent the set of original mini-batch nodes.

from torch_geometric.loader import HGTLoader
from torch_geometric.datasets import OGB_MAG

hetero_data = OGB_MAG(path)[0]

loader = HGTLoader(
    hetero_data,
    # Sample 512 nodes per type and per iteration for 4 iterations:
    num_samples={key: [512] * 4 for key in hetero_data.node_types},
    # Use a batch size of 128 for sampling training nodes of type "paper":
    batch_size=128,
    input_nodes=('paper', hetero_data['paper'].train_mask),
)

sampled_hetero_data = next(iter(loader))
print(sampled_hetero_data['paper'].batch_size)

128

Parameters:

class ClusterData(data, num_parts: int, recursive: bool = False, save_dir: Optional[str] = None, filename: Optional[str] = None, log: bool = True, keep_inter_cluster_edges: bool = False, sparse_format: Literal['csr', 'csc'] = 'csr')[source]

Clusters/partitions a graph data object into multiple subgraphs, as motivated by the “Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks” paper.

Note

The underlying METIS algorithm requires undirected graphs as input.

Parameters:

class ClusterLoader(cluster_data, **kwargs)[source]

The data loader scheme from the “Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks” paper which merges partitioned subgraphs and their between-cluster links from a large-scale graph data object to form a mini-batch.
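A typical Cluster-GCN setup combines ClusterData and ClusterLoader; a minimal sketch (the dataset, partition count, and batch size are illustrative):

from torch_geometric.datasets import Planetoid
from torch_geometric.loader import ClusterData, ClusterLoader

data = Planetoid(root='/tmp/Cora', name='Cora')[0]

# Partition the graph into 100 subgraphs (requires METIS):
cluster_data = ClusterData(data, num_parts=100)
# Merge 10 random partitions plus their between-cluster links per mini-batch:
loader = ClusterLoader(cluster_data, batch_size=10, shuffle=True)

for batch in loader:
    print(batch.num_nodes)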

Parameters:

class GraphSAINTSampler(data, batch_size: int, num_steps: int = 1, sample_coverage: int = 0, save_dir: Optional[str] = None, log: bool = True, **kwargs)[source]

The GraphSAINT sampler base class from the “GraphSAINT: Graph Sampling Based Inductive Learning Method” paper. Given a graph in a data object, this class samples nodes and constructs subgraphs that can be processed in a mini-batch fashion. Normalization coefficients for each mini-batch are given via node_norm and edge_norm data attributes.

Parameters:

class GraphSAINTNodeSampler(data, batch_size: int, num_steps: int = 1, sample_coverage: int = 0, save_dir: Optional[str] = None, log: bool = True, **kwargs)[source]

The GraphSAINT node sampler class (see GraphSAINTSampler).

class GraphSAINTEdgeSampler(data, batch_size: int, num_steps: int = 1, sample_coverage: int = 0, save_dir: Optional[str] = None, log: bool = True, **kwargs)[source]

The GraphSAINT edge sampler class (see GraphSAINTSampler).

class GraphSAINTRandomWalkSampler(data, batch_size: int, walk_length: int, num_steps: int = 1, sample_coverage: int = 0, save_dir: Optional[str] = None, log: bool = True, **kwargs)[source]

The GraphSAINT random walk sampler class (see GraphSAINTSampler).
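A minimal usage sketch (assuming data is a homogeneous Data object; the batch size, walk length, and step counts are illustrative):

from torch_geometric.loader import GraphSAINTRandomWalkSampler

loader = GraphSAINTRandomWalkSampler(
    data,
    batch_size=6000,      # Number of root nodes per random-walk batch.
    walk_length=2,        # Length of each random walk.
    num_steps=5,          # Mini-batches per epoch.
    sample_coverage=100,  # Samples per node for estimating normalization.
)

for batch in loader:
    # With `sample_coverage > 0`, `batch.node_norm` and `batch.edge_norm`
    # hold the GraphSAINT normalization coefficients.
    print(batch.num_nodes)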

Parameters:

walk_length (int) – The length of each random walk.

class ShaDowKHopSampler(data: Data, depth: int, num_neighbors: int, node_idx: Optional[Tensor] = None, replace: bool = False, **kwargs)[source]

The ShaDow \(k\)-hop sampler from the “Decoupling the Depth and Scope of Graph Neural Networks” paper. Given a graph in a data object, the sampler will create shallow, localized subgraphs. A deep GNN on this local graph then smooths the informative local signals.
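A minimal usage sketch (assuming data is a homogeneous Data object with a train_mask; all numbers are illustrative):

from torch_geometric.loader import ShaDowKHopSampler

loader = ShaDowKHopSampler(
    data,
    depth=2,          # Sample 2-hop neighborhoods.
    num_neighbors=5,  # Sample at most 5 neighbors per node and hop.
    node_idx=data.train_mask,
    batch_size=128,
)

for batch in loader:
    # Each localized subgraph keeps a pointer to its root node, so that
    # per-root readout can be applied downstream.
    print(batch.num_graphs, batch.root_n_id)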

Parameters:

class RandomNodeLoader(data: Union[Data, HeteroData], num_parts: int, **kwargs)[source]

A data loader that randomly samples nodes within a graph and returns their induced subgraph.
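A minimal usage sketch (assuming data is a Data object; the number of parts is illustrative):

from torch_geometric.loader import RandomNodeLoader

# Split the node set into 10 random parts; each batch is one induced subgraph:
loader = RandomNodeLoader(data, num_parts=10, shuffle=True)

for batch in loader:
    print(batch.num_nodes)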

Parameters:

class ZipLoader(loaders: Union[List[NodeLoader], List[LinkLoader]], filter_per_worker: Optional[bool] = None, **kwargs)[source]

A loader that returns a tuple of data objects by sampling from multiple NodeLoader or LinkLoader instances.

Parameters:

class DataListLoader(dataset: Union[Dataset, List[BaseData]], batch_size: int = 1, shuffle: bool = False, **kwargs)[source]

A data loader which batches data objects from a torch_geometric.data.Dataset to a Python list. Data objects can be either of type Data or HeteroData.

Note

This data loader should be used for multi-GPU support via torch_geometric.nn.DataParallel.
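A sketch of this multi-GPU pattern (dataset and model are assumed to already exist; DataParallel scatters the Python list across the available GPUs):

from torch_geometric.loader import DataListLoader
from torch_geometric.nn import DataParallel

loader = DataListLoader(dataset, batch_size=32, shuffle=True)
model = DataParallel(model)  # Wraps an existing GNN model.

for data_list in loader:  # `data_list` is a plain Python list of `Data` objects.
    out = model(data_list)  # `DataParallel` scatters the list across GPUs.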

Parameters:

class DenseDataLoader(dataset: Union[Dataset, List[Data]], batch_size: int = 1, shuffle: bool = False, **kwargs)[source]

A data loader which batches data objects from a torch_geometric.data.Dataset to a torch_geometric.data.Batch object by stacking all attributes in a new dimension.

Note

To make use of this data loader, all graph attributes in the dataset need to have the same shape. In particular, this data loader should only be used when working with dense adjacency matrices.
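A sketch of the usual setup, in which a ToDense transform first pads every graph to a common number of nodes (the dataset and max_nodes value are illustrative):

import torch_geometric.transforms as T
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DenseDataLoader

max_nodes = 150  # Illustrative upper bound on nodes per graph.
dataset = TUDataset(root='/tmp/PROTEINS', name='PROTEINS',
                    transform=T.ToDense(max_nodes),
                    pre_filter=lambda data: data.num_nodes <= max_nodes)
loader = DenseDataLoader(dataset, batch_size=32)

for batch in loader:
    print(batch.adj.shape)  # [32, max_nodes, max_nodes] dense adjacency.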

Parameters:

class TemporalDataLoader(data: TemporalData, batch_size: int = 1, neg_sampling_ratio: float = 0.0, **kwargs)[source]

A data loader which merges successive events of a torch_geometric.data.TemporalData to a mini-batch.
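A minimal usage sketch (assuming data is a TemporalData object with src, dst, t, and msg attributes; the neg_dst attribute mentioned below is an assumption about what the loader attaches when negative sampling is enabled):

from torch_geometric.loader import TemporalDataLoader

loader = TemporalDataLoader(data, batch_size=200, neg_sampling_ratio=1.0)

for batch in loader:
    # Each batch holds a contiguous slice of events; with
    # `neg_sampling_ratio > 0`, sampled negative destination nodes are
    # attached as `batch.neg_dst`.
    print(batch.src.size(0))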

Parameters:

class NeighborSampler(edge_index: Union[Tensor, SparseTensor], sizes: List[int], node_idx: Optional[Tensor] = None, num_nodes: Optional[int] = None, return_e_id: bool = True, transform: Optional[Callable] = None, **kwargs)[source]

The neighbor sampler from the “Inductive Representation Learning on Large Graphs” paper, which allows for mini-batch training of GNNs on large-scale graphs where full-batch training is not feasible.

Given a GNN with \(L\) layers and a specific mini-batch of nodes node_idx for which we want to compute embeddings, this module iteratively samples neighbors and constructs bipartite graphs that simulate the actual computation flow of GNNs.

More specifically, sizes denotes how many neighbors we want to sample for each node in each layer. This module then takes in these sizes and iteratively samples sizes[l] for each node involved in layer l. In the next layer, sampling is repeated for the union of nodes that were already encountered. The actual computation graphs are then returned in reverse-mode, meaning that we pass messages from a larger set of nodes to a smaller one, until we reach the nodes for which we originally wanted to compute embeddings.

Hence, an item returned by NeighborSampler holds the current batch_size, the IDs n_id of all nodes involved in the computation, and a list of bipartite graph objects via the tuple (edge_index, e_id, size), where edge_index represents the bipartite edges between source and target nodes, e_id denotes the IDs of original edges in the full graph, and size holds the shape of the bipartite graph. For each bipartite graph, target nodes are also included at the beginning of the list of source nodes so that one can easily apply skip-connections or add self-loops.
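For reference, the classic iteration pattern (deprecated in favor of NeighborLoader, see the warning below) looks roughly as follows, assuming data is a Data object with a train_mask:

from torch_geometric.loader import NeighborSampler

loader = NeighborSampler(data.edge_index, node_idx=data.train_mask,
                         sizes=[25, 10], batch_size=1024, shuffle=True)

for batch_size, n_id, adjs in loader:
    # `adjs` holds one `(edge_index, e_id, size)` tuple per layer, ordered
    # from the outermost hop down to the target nodes.
    x = data.x[n_id]  # Gather features of all involved nodes.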

Warning

NeighborSampler is deprecated and will be removed in a future release. Use torch_geometric.loader.NeighborLoader instead.

Parameters:

class ImbalancedSampler(dataset: Union[Dataset, Data, List[Data], Tensor], input_nodes: Optional[Tensor] = None, num_samples: Optional[int] = None)[source]

A weighted random sampler that randomly samples elements according to class distribution. As such, it will either remove samples from the majority class (under-sampling) or add more examples from the minority class (over-sampling).

Graph-level sampling:

from torch_geometric.loader import DataLoader, ImbalancedSampler

sampler = ImbalancedSampler(dataset)
loader = DataLoader(dataset, batch_size=64, sampler=sampler, ...)

Node-level sampling:

from torch_geometric.loader import NeighborLoader, ImbalancedSampler

sampler = ImbalancedSampler(data, input_nodes=data.train_mask)
loader = NeighborLoader(data, input_nodes=data.train_mask,
                        batch_size=64, num_neighbors=[-1, -1],
                        sampler=sampler, ...)

You can also pass in the class labels directly as a torch.Tensor:

from torch_geometric.loader import NeighborLoader, ImbalancedSampler

sampler = ImbalancedSampler(data.y)
loader = NeighborLoader(data, input_nodes=data.train_mask,
                        batch_size=64, num_neighbors=[-1, -1],
                        sampler=sampler, ...)

Parameters:

class DynamicBatchSampler(dataset: Dataset, max_num: int, mode: str = 'node', shuffle: bool = False, skip_too_big: bool = False, num_steps: Optional[int] = None)[source]

Dynamically adds samples to a mini-batch up to a maximum size (either based on number of nodes or number of edges). When data samples have a wide range in sizes, specifying a mini-batch size in terms of number of samples is not ideal and can cause CUDA OOM errors.

Within the DynamicBatchSampler, the number of steps per epoch is ambiguous, depending on the order of the samples. By default, the __len__() will be undefined. This is fine for most cases, but progress bars will be infinite. Alternatively, num_steps can be supplied to cap the number of mini-batches produced by the sampler.

from torch_geometric.loader import DataLoader, DynamicBatchSampler

sampler = DynamicBatchSampler(dataset, max_num=10000, mode="node")
loader = DataLoader(dataset, batch_sampler=sampler, ...)

Parameters:

class PrefetchLoader(loader: DataLoader, device: Optional[device] = None)[source]

A GPU prefetcher class for asynchronously transferring data of a torch.utils.data.DataLoader from host memory to device memory.
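A minimal sketch wrapping an existing loader (the wrapped loader and the CUDA device are assumptions):

import torch
from torch_geometric.loader import PrefetchLoader

# Wrap any `torch.utils.data.DataLoader`, e.g., a `NeighborLoader`:
prefetch_loader = PrefetchLoader(loader, device=torch.device('cuda'))

for batch in prefetch_loader:
    pass  # Batches arrive already transferred to the target device.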

Parameters:

class CachedLoader(loader: DataLoader, device: Optional[device] = None, transform: Optional[Callable] = None)[source]

A loader to cache mini-batch outputs, e.g., obtained during NeighborLoader iterations.
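A minimal sketch (the wrapped loader is assumed to exist); the first pass populates the cache, and subsequent passes replay the cached, device-resident batches:

import torch
from torch_geometric.loader import CachedLoader

cached_loader = CachedLoader(loader, device=torch.device('cuda'))

for epoch in range(10):
    for batch in cached_loader:  # Sampled once, replayed from cache afterwards.
        pass

cached_loader.clear()  # Free the cached mini-batches.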

Parameters:

clear()[source]

Clears the cache.

class AffinityMixin[source]

A context manager to enable CPU affinity for data loader workers (only used when running on CPU devices).

Affinitization places data loader worker threads on specific CPU cores. In effect, it allows for more efficient local memory allocation and reduces remote memory calls. Every time a process or thread moves from one core to another, registers and caches need to be flushed and reloaded. This can become very costly if it happens often, and our threads may also no longer be close to their data, or be able to share data in a cache.

See here for the accompanying tutorial.

Warning

To correctly affinitize compute threads (i.e., with KMP_AFFINITY), please make sure that you exclude loader_cores from the list of cores available for the main process. Otherwise, core oversubscription will occur and degrade performance.

loader = NeighborLoader(data, num_workers=3)
with loader.enable_cpu_affinity(loader_cores=[0, 1, 2]):
    for batch in loader:
        pass

enable_cpu_affinity(loader_cores: Optional[Union[List[List[int]], List[int]]] = None) → None[source]

Enables CPU affinity.

Parameters:

loader_cores (List[int], optional) – List of CPU cores to which data loader workers should be affinitized. By default, workers will be affinitized to the cores of NUMA node 0. If used with the "spawn" multiprocessing context, it will automatically enable multithreading and use multiple cores per worker.

Return type:

None

class RAGQueryLoader(data: Tuple[RAGFeatureStore, RAGGraphStore], local_filter: Optional[Callable[[Data, Any], Data]] = None, seed_nodes_kwargs: Optional[Dict[str, Any]] = None, seed_edges_kwargs: Optional[Dict[str, Any]] = None, sampler_kwargs: Optional[Dict[str, Any]] = None, loader_kwargs: Optional[Dict[str, Any]] = None)[source]

Loader meant for making RAG queries from a remote backend.

query(query: Any) → Data[source]

Retrieves a subgraph associated with the query, along with all of its feature attributes.

Return type:

Data

class RAGFeatureStore(*args, **kwargs)[source]

Feature store template for remote GNN RAG backend.

abstract retrieve_seed_nodes(query: Any, **kwargs) → Union[Tensor, None, str, Tuple[str, Optional[Tensor]]][source]

Compares the query against all nodes in order to retrieve the closest ones, and returns the indices of the nodes that are to serve as seeds for the RAG sampler.

Return type:

Union[Tensor, None, str, Tuple[str, Optional[Tensor]]]

abstract retrieve_seed_edges(query: Any, **kwargs) → Union[Tensor, None, Tuple[str, str, str], Tuple[Tuple[str, str, str], Optional[Tensor]]][source]

Compares the query against all edges in order to retrieve the closest ones, and returns the edge indices that are to serve as seeds for the RAG sampler.

Return type:

Union[Tensor, None, Tuple[str, str, str], Tuple[Tuple[str, str, str], Optional[Tensor]]]

abstract load_subgraph(sample: Union[SamplerOutput, HeteroSamplerOutput]) → Union[Data, HeteroData][source]

Combines sampled subgraph output with features in a Data object.

Return type:

Union[Data, HeteroData]

class RAGGraphStore(*args, **kwargs)[source]

Graph store template for remote GNN RAG backend.

abstract sample_subgraph(seed_nodes: Union[Tensor, None, str, Tuple[str, Optional[Tensor]]], seed_edges: Union[Tensor, None, Tuple[str, str, str], Tuple[Tuple[str, str, str], Optional[Tensor]]], **kwargs) → Union[SamplerOutput, HeteroSamplerOutput][source]

Samples a subgraph using the seed nodes and edges.

Return type:

Union[SamplerOutput, HeteroSamplerOutput]

abstract register_feature_store(feature_store: FeatureStore)[source]

Registers a feature store to be used with the sampler. Samplers need information from the feature store in order to work properly on heterogeneous graphs.