This guide walks you through installing the THRML library and explains, step by step, a quick example that demonstrates its core functionality. THRML is designed for building and sampling probabilistic graphical models (PGMs), with a focus on efficiency and compatibility with advanced hardware. For an overview of THRML's capabilities and architecture, refer to the [THRML: Thermodynamic Hypergraphical Model Library] page.
THRML requires Python 3.10 or newer. You can install it with pip (`pip install thrml`) or with uv (`uv pip install thrml`).
This example demonstrates how to set up and sample a small Ising chain using THRML's block Gibbs sampling. We will create a simple 1D Ising model and sample its states efficiently.
Define Nodes and Edges:
- `nodes = [SpinNode() for _ in range(5)]`: We create five `SpinNode` instances. In THRML, a `SpinNode` represents a binary random variable that can take the values -1 or +1.
- `edges = [(nodes[i], nodes[i+1]) for i in range(4)]`: We define a simple linear chain in which each node is connected to its immediate neighbor, giving the four edges of a 1D Ising chain.

Define Model Parameters:
- `biases = jnp.zeros((5,))`: One bias (external magnetic field) per node, all set to zero here.
- `weights = jnp.ones((4,)) * 0.5`: One coupling strength per edge, all equal to 0.5. In an Ising model, these weights set the interaction energy between neighboring spins; with the usual sign convention, positive weights favor aligned neighbors.
- `beta = jnp.array(1.0)`: The inverse temperature. It scales the energy in the Boltzmann distribution p(s) ∝ exp(-β E(s)), so larger values concentrate probability on low-energy states.
- `model = IsingEBM(nodes, edges, biases, weights, beta)`: Instantiates an Ising energy-based model (EBM) from the nodes, edges, biases, weights, and beta. This object defines the energy function, and hence the probability distribution, of our Ising chain.

Define Sampling Blocks:
- `free_blocks = [Block(nodes[::2]), Block(nodes[1::2])]`: For block Gibbs sampling, the nodes are partitioned into "free blocks." Here we create two blocks: one containing the even-indexed nodes (`nodes[0]`, `nodes[2]`, `nodes[4]`) and one containing the odd-indexed nodes (`nodes[1]`, `nodes[3]`). This is the standard two-coloring of a bipartite graph: no two nodes in the same block share an edge, so each node's conditional distribution depends only on nodes in the other block. All nodes in one block can therefore be sampled in parallel while the other block is held fixed.
- `program = IsingSamplingProgram(model, free_blocks, clamped_blocks=[])`: Sets up an `IsingSamplingProgram` that ties the `IsingEBM` to the chosen `free_blocks`. `clamped_blocks=[]` means that no nodes are held fixed during sampling.

Initialize State and Schedule Sampling:
- `key = jax.random.key(0)`: Creates a JAX random key from a fixed seed, so the random operations are reproducible.
- `k_init, k_samp = jax.random.split(key, 2)`: Splits the master key into two subkeys: one for initializing the state and one for the main sampling loop.
- `init_state = hinton_init(k_init, model, free_blocks, ())`: Generates the initial state of the free blocks. `hinton_init` is a heuristic that initializes spins at random according to their local biases and the `beta` parameter. The `()` argument is an empty batch shape, meaning a single, unbatched initial state.
- `schedule = SamplingSchedule(...)`: Creates a `SamplingSchedule` object that controls the sampling process:
  - `n_warmup=100`: The first 100 Gibbs steps are warm-up; their states are discarded so that the chain can approach its equilibrium distribution.
  - `n_samples=1000`: After warm-up, 1000 samples of the system's state are collected.
  - `steps_per_sample=2`: The sampler performs 2 Gibbs steps between consecutive collected samples. Thinning the chain this way reduces autocorrelation between consecutive samples, so they are closer to independent.

Run Sampling:
- `samples = sample_states(...)`: The main call that executes block Gibbs sampling. It takes the sampling key (`k_samp`), the program, the schedule, the initial free-block state, an empty list for clamped states, and a list of blocks whose states should be observed and returned. Here, `[Block(nodes)]` requests the states of all nodes in the model. The output `samples` will be a JAX array with dimensions (n_samples, total_nodes), that is, (1000, 5) for this chain.
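Because the chain has only five spins, its distribution can be computed exactly by enumerating all 2^5 = 32 configurations, which gives a reference value to check sampled statistics against. The sketch below is plain JAX and does not use THRML; it assumes the standard Ising convention p(s) ∝ exp(β(Σ_i b_i s_i + Σ_(i,j) J_ij s_i s_j)), which we assume matches how `IsingEBM` combines `biases`, `weights`, and `beta`. The helper name `log_weight` is hypothetical.

```python
import itertools

import jax
import jax.numpy as jnp

# Parameters mirroring the quick example: 5 spins on an open chain, 4 edges.
# Assumption: the standard Ising convention, not necessarily IsingEBM's internals.
n = 5
biases = jnp.zeros((n,))
weights = jnp.ones((n - 1,)) * 0.5
beta = 1.0
i_idx = jnp.arange(n - 1)  # left endpoint of each chain edge
j_idx = i_idx + 1          # right endpoint of each chain edge


def log_weight(s):
    """Unnormalized log-probability: beta * (sum_i b_i s_i + sum_edges J s_i s_j)."""
    field = jnp.dot(biases, s)
    coupling = jnp.sum(weights * s[i_idx] * s[j_idx])
    return beta * (field + coupling)


# Enumerate all 2**5 = 32 configurations with spins in {-1, +1}.
states = jnp.array(list(itertools.product([-1.0, 1.0], repeat=n)))
log_w = jax.vmap(log_weight)(states)

# Normalize, shifting by the max log-weight for numerical stability.
probs = jnp.exp(log_w - jnp.max(log_w))
probs = probs / jnp.sum(probs)

# Exact nearest-neighbour correlation <s_0 s_1> and magnetization <s_0>.
nn_corr = jnp.sum(probs * states[:, 0] * states[:, 1])
mag = jnp.sum(probs * states[:, 0])
print(float(nn_corr))  # ≈ 0.4621, i.e. tanh(beta * J) = tanh(0.5)
print(float(mag))      # ≈ 0.0, by the s -> -s symmetry at zero bias
```

For an open chain with zero bias, the nearest-neighbor correlation equals tanh(βJ) exactly, so the empirical mean of neighboring spin products in `samples` should land near 0.46 up to Monte Carlo error. Exact enumeration is only feasible for tiny models, since the number of states grows as 2^n.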
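To make the two-color block update and the schedule parameters concrete, here is a hand-written block Gibbs sampler for the same five-spin chain in plain JAX. It is a minimal sketch, not THRML's implementation: the helpers `update_block`, `gibbs_step`, and `run` are hypothetical, it initializes spins uniformly at random instead of calling `hinton_init`, and it assumes the same Ising convention as above, under which P(s_i = +1 | neighbors) = sigmoid(2β h_i) with local field h_i = b_i + Σ_j J_ij s_j.

```python
import jax
import jax.numpy as jnp

n = 5
biases = jnp.zeros((n,))
weights = jnp.ones((n - 1,)) * 0.5
beta = 1.0

# Dense symmetric coupling matrix for the open chain 0-1-2-3-4 (zero diagonal).
idx = jnp.arange(n - 1)
J = jnp.zeros((n, n)).at[idx, idx + 1].set(weights).at[idx + 1, idx].set(weights)

# Two-colour partition: even sites and odd sites (no edges within a colour).
even_mask = jnp.arange(n) % 2 == 0
odd_mask = ~even_mask


def update_block(key, s, mask):
    """Resample every spin selected by `mask` in parallel, given all other spins."""
    h = biases + J @ s                     # local field on every site
    p_up = jax.nn.sigmoid(2.0 * beta * h)  # P(s_i = +1 | neighbours)
    u = jax.random.uniform(key, (n,))
    proposal = 2.0 * (u < p_up).astype(jnp.float32) - 1.0
    return jnp.where(mask, proposal, s)    # keep spins outside the block unchanged


def gibbs_step(s, key):
    """One Gibbs step: update the even block, then the odd block."""
    k_even, k_odd = jax.random.split(key)
    s = update_block(k_even, s, even_mask)
    s = update_block(k_odd, s, odd_mask)
    return s, None


def run(key, s0, n_warmup=100, n_samples=1000, steps_per_sample=2):
    k_warm, k_samp = jax.random.split(key)
    # Warm-up steps: the final state is kept, the intermediate states are discarded.
    s, _ = jax.lax.scan(gibbs_step, s0, jax.random.split(k_warm, n_warmup))

    def collect(s, key):
        # Run `steps_per_sample` Gibbs steps, then record the resulting state.
        s, _ = jax.lax.scan(gibbs_step, s, jax.random.split(key, steps_per_sample))
        return s, s

    _, samples = jax.lax.scan(collect, s, jax.random.split(k_samp, n_samples))
    return samples


key = jax.random.key(0)
k_init, k_run = jax.random.split(key)
s0 = 2.0 * jax.random.bernoulli(k_init, 0.5, (n,)).astype(jnp.float32) - 1.0
samples = run(k_run, s0)
print(samples.shape)  # (1000, 5)
```

Because even sites neighbor only odd sites, `J @ s` evaluated for the even block reads only odd spins (and vice versa), which is why a whole block can be drawn in one vectorized call. The mean of `samples[:, :-1] * samples[:, 1:]` should come out close to tanh(0.5) ≈ 0.46, up to Monte Carlo error.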