This guide walks you through installing the THRML library and explains a quick example that demonstrates its core functionality. THRML is a library for building and sampling probabilistic graphical models (PGMs), with a focus on efficient sampling and compatibility with advanced hardware. For an overview of THRML's capabilities and architecture, refer to the [THRML: Thermodynamic Hypergraphical Model Library] page.
THRML requires Python 3.10 or newer. You can install it using pip or uv:
or
This example demonstrates how to set up and sample a small Ising chain using THRML's block Gibbs sampling. We will create a simple 1D Ising model and sample its states efficiently.
Define Nodes and Edges:
nodes = [SpinNode() for _ in range(5)]: We create five SpinNode instances. In THRML, a SpinNode represents a binary random variable (a spin) that takes the value -1 or +1.
edges = [(nodes[i], nodes[i+1]) for i in range(4)]: We connect each node to its immediate successor, which gives four edges forming a linear 1D Ising chain.
Define Model Parameters:
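With the chain's nodes and edges in place, the parameters defined in this step determine the model's energy. As a rough illustration, here is a stdlib-only sketch of the standard Ising energy E(s) = -Σ_i b_i s_i - Σ_(i,j) J_ij s_i s_j for the same five-spin chain, with each configuration's probability proportional to exp(-β E(s)). This is not THRML code: the names ising_energy and boltzmann_probabilities and the plain-list representation are hypothetical, and THRML's IsingEBM may fold β into the energy differently. Only the values (biases 0, couplings 0.5, β = 1.0) come from the example.

```python
import itertools
import math

# Plain-Python stand-ins for the example's parameters (hypothetical, not THRML objects).
N_NODES = 5
EDGES = [(i, i + 1) for i in range(N_NODES - 1)]  # same chain topology as `edges`
BIASES = [0.0] * N_NODES                           # mirrors jnp.zeros((5,))
WEIGHTS = [0.5] * len(EDGES)                       # mirrors jnp.ones((4,)) * 0.5
BETA = 1.0                                         # mirrors jnp.array(1.0)


def ising_energy(spins):
    """Standard Ising energy E(s) = -sum_i b_i s_i - sum_(i,j) J_ij s_i s_j."""
    field_term = sum(b * s for b, s in zip(BIASES, spins))
    coupling_term = sum(w * spins[i] * spins[j] for w, (i, j) in zip(WEIGHTS, EDGES))
    return -field_term - coupling_term


def boltzmann_probabilities(beta=BETA):
    """Exact p(s) = exp(-beta * E(s)) / Z, enumerating all 2**5 configurations."""
    configs = list(itertools.product((-1, 1), repeat=N_NODES))
    unnormalized = [math.exp(-beta * ising_energy(s)) for s in configs]
    z = sum(unnormalized)  # partition function
    return {s: u / z for s, u in zip(configs, unnormalized)}


aligned = (1, 1, 1, 1, 1)
alternating = (1, -1, 1, -1, 1)
print(ising_energy(aligned))      # -2.0: all four neighboring pairs agree
print(ising_energy(alternating))  # 2.0: all four neighboring pairs disagree
probs = boltzmann_probabilities()
print(probs[aligned] / probs[alternating])  # exp(4), about 54.6
```

Exhaustive enumeration only works because the chain has 2^5 = 32 states; for larger models this is exactly the cost that sampling avoids.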
biases = jnp.zeros((5,)): An array of zeros representing the external magnetic field (bias) on each of the five nodes; with all biases at zero, neither spin value is favored on its own.
weights = jnp.ones((4,)) * 0.5: The coupling strength (weight) of each of the four edges, here 0.5. In an Ising model, these weights set the interaction energy between neighboring spins; in the usual sign convention, positive couplings favor neighbors taking the same value.
beta = jnp.array(1.0): The inverse temperature β, which scales the energy function; larger values of β concentrate probability on low-energy configurations.
model = IsingEBM(nodes, edges, biases, weights, beta): An IsingEBM (energy-based model) is instantiated from the nodes, edges, biases, weights, and beta defined above. This object holds the structure and parameters that define the Ising chain's energy function.
Define Sampling Blocks:
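The block choice in this step relies on the chain being two-colorable. A minimal sketch, using plain integer indices instead of THRML's SpinNode and Block objects (the helper is_valid_coloring is hypothetical), checks that the even/odd split puts every node in exactly one block and leaves no edge inside a block. That condition is what makes the spins of one block conditionally independent given the other block.

```python
# Plain-index stand-ins for nodes[::2] and nodes[1::2]; not THRML objects.
N_NODES = 5
EDGES = [(i, i + 1) for i in range(N_NODES - 1)]
nodes = list(range(N_NODES))
free_blocks = [nodes[::2], nodes[1::2]]  # [[0, 2, 4], [1, 3]]


def is_valid_coloring(blocks, edges, n_nodes):
    """True if the blocks partition the nodes and no edge joins two nodes of one block."""
    block_of = {}
    for b, block in enumerate(blocks):
        for node in block:
            if node in block_of:
                return False  # a node appears in two blocks
            block_of[node] = b
    if len(block_of) != n_nodes:
        return False  # some node belongs to no block
    return all(block_of[i] != block_of[j] for i, j in edges)


print(free_blocks)                                     # [[0, 2, 4], [1, 3]]
print(is_valid_coloring(free_blocks, EDGES, N_NODES))  # True
# Putting neighbors 0 and 1 in the same block breaks the condition.
print(is_valid_coloring([[0, 1, 2], [3, 4]], EDGES, N_NODES))  # False
```

On graphs that are not bipartite, the same check would simply fail for any two-block split, and more blocks would be needed.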
free_blocks = [Block(nodes[::2]), Block(nodes[1::2])]: For block Gibbs sampling, the nodes are partitioned into "free blocks." Here we create two blocks: one containing the nodes at even indices (nodes[0], nodes[2], nodes[4]) and one containing the nodes at odd indices (nodes[1], nodes[3]). Because the chain is bipartite (2-colorable), no two nodes in the same block share an edge, so the spins within a block are conditionally independent given the other block. All nodes in one block can therefore be sampled in parallel while the other block is held fixed.
program = IsingSamplingProgram(model, free_blocks, clamped_blocks=[]): An IsingSamplingProgram ties the IsingEBM to the chosen free_blocks. Passing clamped_blocks=[] means that no nodes are held fixed (clamped) during sampling.
Initialize State and Schedule Sampling:
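The initialization in this step can be pictured with a small stand-in for hinton_init. This sketch assumes that hinton_init draws each spin independently from the distribution its bias alone would give, P(s_i = +1) = σ(2β b_i), where σ is the logistic function; that formula is an assumption about THRML's heuristic, and hinton_style_init is a hypothetical name. Python's random.Random(seed) plays the role of JAX's explicit key.

```python
import math
import random


def sigmoid(x):
    """Numerically stable logistic function 1 / (1 + exp(-x))."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def hinton_style_init(rng, biases, beta):
    """Hypothetical stand-in for hinton_init (assumed behavior): draw each spin
    independently from its bias-only marginal,
    P(s_i = +1) = exp(beta*b_i) / (exp(beta*b_i) + exp(-beta*b_i)) = sigmoid(2*beta*b_i).
    """
    return [1 if rng.random() < sigmoid(2.0 * beta * b) else -1 for b in biases]


rng = random.Random(0)  # seeded generator, playing the role of k_init
init_state = hinton_style_init(rng, biases=[0.0] * 5, beta=1.0)
print(init_state)  # five spins in {-1, +1}; with zero biases each is a fair coin flip
```

With the example's zero biases, σ(0) = 1/2, so every spin starts as an independent fair coin flip; strong biases would instead push the initial spins toward their preferred values.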
key = jax.random.key(0): Creates a JAX random key from seed 0 so that the run is reproducible.
k_init, k_samp = jax.random.split(key, 2): Splits the master key into two independent subkeys, one for initializing the state and one for the main sampling loop.
init_state = hinton_init(k_init, model, free_blocks, ()): Generates the initial state of the free blocks. hinton_init is a heuristic that initializes each spin at random, with probabilities determined by its local bias and the beta parameter. The empty tuple () means there is no batch dimension, so a single initial state is produced.
schedule = SamplingSchedule(...): Creates a SamplingSchedule object that controls the sampling process:
n_warmup=100: The first 100 Gibbs steps form a warm-up (burn-in) phase and are discarded, giving the chain time to approach its equilibrium distribution.
n_samples=1000: After warm-up, 1000 states are collected.
steps_per_sample=2: The sampler performs 2 Gibbs steps between consecutive collected samples. This reduces the autocorrelation between successive samples, so each one carries more independent information.
Run Sampling:
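Before running the sampler, it can help to count the work implied by n_warmup=100, n_samples=1000, and steps_per_sample=2. This sketch assumes one particular reading of SamplingSchedule: the first n_warmup Gibbs iterations are discarded, each iteration updates every free block once, and a state is recorded after every steps_per_sample further iterations. THRML's exact convention (for example, whether a sample is taken immediately after warm-up) may differ by a few iterations, and recorded_iterations is a hypothetical helper.

```python
def recorded_iterations(n_warmup, n_samples, steps_per_sample):
    """Hypothetical schedule model (assumed convention): discard the first
    n_warmup Gibbs iterations, then record the state after every
    steps_per_sample further iterations. Returns 1-based iteration numbers."""
    return [n_warmup + k * steps_per_sample for k in range(1, n_samples + 1)]


iters = recorded_iterations(n_warmup=100, n_samples=1000, steps_per_sample=2)
n_blocks = 2  # the even and odd free blocks
print(len(iters))            # 1000
print(iters[0], iters[-1])   # 102 2100
print(iters[-1] * n_blocks)  # 4200 block updates in total
```

Under this reading, the run costs about 2100 iterations, of which only every second one after warm-up is kept.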
samples = sample_states(...): This is the main call that runs block Gibbs sampling. It takes the sampling key (k_samp), the program, the schedule, the initial free state, an empty list for the clamped states, and a list of blocks whose states should be observed and returned. Here, [Block(nodes)] requests the states of all nodes in the model. The collected states form a JAX array of shape (n_samples, number of observed nodes), which is (1000, 5) for this example.
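To make the whole pipeline concrete, here is a minimal, stdlib-only sketch of two-block Gibbs sampling on the same chain. It is not THRML's implementation: the names (block_gibbs_chain, local_field, update_block) are hypothetical, spins are plain ±1 integers, the blocks are updated by a Python loop rather than in parallel on an accelerator, and the chain starts from a uniform random state (which is what a bias-based initialization gives when all biases are zero). It assumes a state is recorded after every steps_per_sample iterations once warm-up ends. Each spin is resampled from its conditional distribution p(s_i = +1 | neighbors) = σ(2β h_i), where h_i = b_i + Σ_j J_ij s_j is the local field.

```python
import math
import random

# Plain-Python stand-ins for the example's model (hypothetical, not THRML objects).
N_NODES = 5
EDGES = [(i, i + 1) for i in range(N_NODES - 1)]
BIASES = [0.0] * N_NODES
WEIGHTS = [0.5] * len(EDGES)
BETA = 1.0
FREE_BLOCKS = [list(range(0, N_NODES, 2)), list(range(1, N_NODES, 2))]

# Adjacency list: node -> [(neighbor, coupling), ...]
NEIGHBORS = {i: [] for i in range(N_NODES)}
for (i, j), w in zip(EDGES, WEIGHTS):
    NEIGHBORS[i].append((j, w))
    NEIGHBORS[j].append((i, w))


def local_field(spins, i):
    """h_i = b_i + sum_j J_ij s_j over the neighbors of node i."""
    return BIASES[i] + sum(w * spins[j] for j, w in NEIGHBORS[i])


def update_block(rng, spins, block):
    """Resample every spin in `block` from p(s_i = +1 | rest) = sigmoid(2 * beta * h_i).
    All fields are computed from the current state first; because no two nodes
    in a block are neighbors, this equals updating them one at a time."""
    fields = [local_field(spins, i) for i in block]
    for i, h in zip(block, fields):
        p_up = 1.0 / (1.0 + math.exp(-2.0 * BETA * h))
        spins[i] = 1 if rng.random() < p_up else -1


def block_gibbs_chain(seed, n_warmup, n_samples, steps_per_sample):
    """Run two-block Gibbs sampling and return n_samples recorded states."""
    rng = random.Random(seed)
    spins = [rng.choice((-1, 1)) for _ in range(N_NODES)]  # uniform random start

    def sweep():
        # One Gibbs iteration: update each free block once, in order.
        for block in FREE_BLOCKS:
            update_block(rng, spins, block)

    for _ in range(n_warmup):
        sweep()  # warm-up iterations are discarded
    samples = []
    for _ in range(n_samples):
        for _ in range(steps_per_sample):
            sweep()
        samples.append(list(spins))  # store a copy of the current state
    return samples


samples = block_gibbs_chain(seed=0, n_warmup=100, n_samples=1000, steps_per_sample=2)
print(len(samples), len(samples[0]))  # 1000 5

# Nearest-neighbor correlation <s_i s_(i+1)>, averaged over the four edges.
corr = sum(s[i] * s[j] for s in samples for i, j in EDGES) / (len(samples) * len(EDGES))
print(f"estimate {corr:.3f}, exact {math.tanh(BETA * WEIGHTS[0]):.3f}")  # exact is 0.462
```

For an open chain with zero biases, the nearest-neighbor correlation is exactly tanh(βJ) = tanh(0.5) ≈ 0.462, so the estimate should land close to that value; a clearly different number would point to a bug in the conditional update.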