Getting Started with Graph Learning
Welcome to the first chapter of our journey into the world of graph neural networks (GNNs). In this chapter, we will delve into the foundations of GNNs and understand why they are crucial tools in modern data analysis and machine learning. To that end, we will answer three essential questions that will provide us with a comprehensive understanding of GNNs.
First, we will explore the significance of graphs as a representation of data, and why they are widely used in various domains such as computer science, biology, and finance. Next, we will delve into the importance of graph learning, where we will understand the different applications of graph learning and the different families of graph learning techniques. Finally, we will focus on the GNN family, highlighting its unique features, performance, and how it stands out compared to other methods.
By the end of this chapter, you will have a clear understanding of why GNNs are important and how they can be used to solve real-world problems. You will also be equipped with the knowledge and skills you need to dive deeper into more advanced topics. So, let’s get started!
In this chapter, we will cover the following main topics:
- Why graphs?
- Why graph learning?
- Why graph neural networks?
The first question we need to address is: why are we interested in graphs in the first place? Graph theory, the mathematical study of graphs, has emerged as a fundamental tool for understanding complex systems and relationships. A graph is a visual representation of a collection of nodes (also called vertices) and edges that connect these nodes, providing a structure to represent entities and their relationships (see Figure 1.1).
Figure 1.1 – Example of a graph with six nodes and five edges
By representing a complex system as a network of entities with interactions, we can analyze their relationships, allowing us to gain a deeper understanding of their underlying structures and patterns. The versatility of graphs makes them a popular choice in various domains, including the following:
- Computer science, where graphs can be used to model the structure of computer programs, making it easier to understand how different components of a system interact with each other
- Physics, where graphs can be used to model physical systems and their interactions, such as the relationship between particles and their properties
- Biology, where graphs can be used to model biological systems, such as metabolic pathways, as a network of interconnected entities
- Social sciences, where graphs can be used to study and understand complex social networks, including the relationships between individuals in a community
- Finance, where graphs can be used to analyze stock market trends and relationships between different financial instruments
- Engineering, where graphs can be used to model and analyze complex systems, such as transportation networks and electrical power grids
These domains naturally exhibit a relational structure. For instance, graphs are a natural representation of social networks: nodes are users, and edges represent friendships. But graphs are so versatile they can also be applied to domains where the relational structure is less natural, unlocking new insights and understanding.
For example, images can be represented as a graph, as in Figure 1.2. Each pixel is a node, and edges represent relationships between neighboring pixels. This allows for the application of graph-based algorithms to image processing and computer vision tasks.
Figure 1.2 – Left: original image; right: graph representation of this image
Similarly, a sentence can be transformed into a graph, where nodes are words and edges represent relationships between adjacent words. This approach is useful in natural language processing and information retrieval tasks, where the context and meaning of words are critical factors.
Unlike text and images, graphs do not have a fixed structure. However, this flexibility also makes graphs more challenging to handle. The absence of a fixed structure means they can have an arbitrary number of nodes and edges, with no specific ordering. In addition, graphs can represent dynamic data, where the connections between entities can change over time. For example, the relationships between users and products can change as they interact with each other. In this scenario, nodes and edges are updated to reflect changes in the real world, such as new users, new products, and new relationships.
In the next section, we will delve deeper into how to use graphs with machine learning to create valuable applications.
Why graph learning?
Graph learning is the application of machine learning techniques to graph data. This study area encompasses a range of tasks aimed at understanding and manipulating graph-structured data. There are many graphs learning tasks, including the following:
- Node classification is a task that involves predicting the category (class) of a node in a graph. For example, it can categorize online users or items based on their characteristics. In this task, the model is trained on a set of labeled nodes and their attributes, and it uses this information to predict the class of unlabeled nodes.
- Link prediction is a task that involves predicting missing links between pairs of nodes in a graph. This is useful in knowledge graph completion, where the goal is to complete a graph of entities and their relationships. For example, it can be used to predict the relationships between people based on their social network connections (friend recommendation).
- Graph classification is a task that involves categorizing different graphs into predefined categories. One example of this is in molecular biology, where molecular structures can be represented as graphs, and the goal is to predict their properties for drug design. In this task, the model is trained on a set of labeled graphs and their attributes, and it uses this information to categorize unseen graphs.
- Graph generation is a task that involves generating new graphs based on a set of desired properties. One of the main applications is generating novel molecular structures for drug discovery. This is achieved by training a model on a set of existing molecular structures and then using it to generate new, unseen structures. The generated structures can be evaluated for their potential as drug candidates and further studied.
Graph learning has many other practical applications that can have a significant impact. One of the most well-known applications is recommender systems, where graph learning algorithms recommend relevant items to users based on their previous interactions and relationships with other items. Another important application is traffic forecasting, where graph learning can improve travel time predictions by considering the complex relationships between different routes and modes of transportation.
The versatility and potential of graph learning make it an exciting field of research and development. The study of graphs has advanced rapidly in recent years, driven by the availability of large datasets, powerful computing resources, and advancements in machine learning and artificial intelligence. As a result, we can list four prominent families of graph learning techniques :
- Graph signal processing, which applies traditional signal processing methods to graphs, such as the graph Fourier transform and spectral analysis. These techniques reveal the intrinsic properties of the graph, such as its connectivity and structure.
- Matrix factorization, which seeks to find low-dimensional representations of large matrices. The goal of matrix factorization is to identify latent factors or patterns that explain the observed relationships in the original matrix. This approach can provide a compact and interpretable representation of the data.
- Random walk, which refers to a mathematical concept used to model the movement of entities in a graph. By simulating random walks over a graph, information about the relationships between nodes can be gathered. This is why they are often used to generate training data for machine learning models.
- Deep learning, which is a subfield of machine learning that focuses on neural networks with multiple layers. Deep learning methods can effectively encode and represent graph data as vectors. These vectors can then be used in various tasks with remarkable performance.
It is important to note that these techniques are not mutually exclusive and often overlap in their applications. In practice, they are often combined to form hybrid models that leverage the strengths of each. For example, matrix factorization and deep learning techniques might be used in combination to learn low-dimensional representations of graph-structured data.
As we delve into the world of graph learning, it is crucial to understand the fundamental building block of any machine learning technique: the dataset. Traditional tabular datasets, such as spreadsheets, represent data as rows and columns with each row representing a single data point. However, in many real-world scenarios, the relationships between data points are just as meaningful as the data points themselves. This is where graph datasets come in. Graph datasets represent data points as nodes in a graph and the relationships between those data points as edges.
Let’s take the tabular dataset shown in Figure 1.3 as an example.
Figure 1.3 – Family tree as a tabular dataset versus a graph dataset
This dataset represents information about five members of a family. Each member has three features (or attributes): name, age, and gender. However, the tabular version of this dataset doesn’t show the connections between these people. On the contrary, the graph version represents them with edges, which allows us to understand the relationships in this family. In many contexts, the connections between nodes are crucial in understanding the data, which is why representing data in graph form is becoming increasingly popular.
Now that we have a basic understanding of graph machine learning and the different types of tasks it involves, we can move on to exploring one of the most important approaches for solving these tasks: graph neural networks.
Why graph neural networks?
In this book, we will focus on the deep learning family of graph learning techniques, often referred to as graph neural networks. GNNs are a new category of deep learning architecture and are specifically designed for graph-structured data. Unlike traditional deep learning algorithms, which have been primarily developed for text and images, GNNs are explicitly made to process and analyze graph datasets (see Figure 1.4).
Figure 1.4 – High-level architecture of a GNN pipeline, with a graph as input and an output that corresponds to a given task
GNNs have emerged as a powerful tool for graph learning and have shown excellent results in various tasks and industries. One of the most striking examples is how a GNN model identified a new antibiotic . The model was trained on 2,500 molecules and was tested on a library of 6,000 compounds. It predicted that a molecule called halicin should be able to kill many antibiotic-resistant bacteria while having low toxicity to human cells. Based on this prediction, the researchers used halicin to treat mice infected with antibiotic-resistant bacteria. They demonstrated its effectiveness and believe the model could be used to design new drugs.
How do GNNs work? Let’s take the example of a node classification task in a social network, like the previous family tree (Figure 1.3). In a node classification task, GNNs take advantage of information from different sources to create a vector representation of each node in the graph. This representation encompasses not only the original node features (such as name, age, and gender) but also information from edge features (such as the strength of relationships between nodes) and global features (such as network-wide statistics).
This is why GNNs are more efficient than traditional machine learning techniques on graphs. Instead of being limited to the original attributes, GNNs enrich the original node features with attributes from neighboring nodes, edges, and global features, making the representation much more comprehensive and meaningful. The new node representations are then used to perform a specific task, such as node classification, regression, or link prediction.
Specifically, GNNs define a graph convolution operation that aggregates information from the neighboring nodes and edges to update the node representation. This operation is performed iteratively, allowing the model to learn more complex relationships between nodes as the number of iterations increases. For example, Figure 1.5 shows how a GNN would calculate the representation of node 5 using neighboring nodes.
Figure 1.5 – Left: input graph; right: computation graph representing how a GNN computes the representation of node 5 based on its neighbors
It is worth noting that Figure 1.5 provides a simplified illustration of a computation graph. In reality, there are various kinds of GNNs and GNN layers, each of which has a unique structure and way of aggregating information from neighboring nodes. These different variants of GNNs also have their own advantages and limitations and are well-suited for specific types of graph data and tasks. When selecting the appropriate GNN architecture for a particular problem, it is crucial to understand the characteristics of the graph data and the desired outcome.
More generally, GNNs, like other deep learning techniques, are most effective when applied to specific problems. These problems are characterized by high complexity, meaning that learning good representations is critical to solving the task at hand. For example, a highly complex task could be recommending the right products among billions of options to millions of customers. On the other hand, some problems, such as finding the youngest member of our family tree, can be solved without any machine learning technique.
Furthermore, GNNs require a substantial amount of data to perform effectively. Traditional machine learning techniques might be a better fit in cases where the dataset is small, as they are less reliant on large amounts of data. However, these techniques do not scale as well as GNNs. GNNs can process bigger datasets thanks to parallel and distributed training. They can also exploit the additional information more efficiently, which produces better results.
In this chapter, we answered three main questions: why graphs, why graph learning, and why graph neural networks? First, we explored the versatility of graphs in representing various data types, such as social networks and transportation networks, but also text and images. We discussed the different applications of graph learning, including node classification and graph classification, and highlighted the four main families of graph learning techniques. Finally, we emphasized the significance of GNNs and their superiority over other techniques, especially regarding large, complex datasets. By answering these three main questions, we aimed to provide a comprehensive overview of the importance of GNNs and why they are becoming vital tools in machine learning.
In Chapter 2, Graph Theory for Graph Neural Networks, we will dive deeper into the basics of graph theory, which provides the foundation for understanding GNNs. This chapter will cover the fundamental concepts of graph theory, including concepts such as adjacency matrices and degrees. Additionally, we will delve into the different types of graphs and their applications, such as directed and undirected graphs, and weighted and unweighted graphs.
-  F. Xia et al., Graph Learning: A Survey, IEEE Transactions on Artificial Intelligence, vol. 2, no. 2, pp. 109–127, Apr. 2021, DOI: 10.1109/tai.2021.3076021. Available at https://arxiv.org/abs/2105.00696
-  A. Trafton, Artificial intelligence yields new antibiotic, MIT News, 20-Feb-2020. [Online]. Available at https://news.mit.edu/2020/artificial-intelligence-identifies-new-antibiotic-0220