This guide demonstrates how to build and sample a discrete Energy-Based Model (EBM) in THRML, featuring a mix of `SpinNode` (binary) and `CategoricalNode` variables. This type of mixed model is common in various domains, from statistical physics to machine learning, and THRML provides the tools to handle such heterogeneous graphical models efficiently.
We will cover:
- Creating `SpinNode` and `CategoricalNode` instances.
- Using `DiscreteEBMFactor` types for biases and interactions.
- Building a `FactorSamplingProgram` with the appropriate `SpinGibbsConditional` and `CategoricalGibbsConditional` samplers.

For a general introduction to THRML and its installation, please refer to the [Getting Started] guide. For detailed API information on the components used here, see the [API Reference: Models & Factors] and [API Reference: Block Management & Sampling] pages.
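Consider a small model with two binary spin variables ($s_0, s_1$) and two 3-category variables ($c_0, c_1$). The energy function includes three kinds of terms, each described in detail below:

- a bias term for each spin;
- a bias vector (one value per category) for each categorical variable;
- mixed spin-categorical interactions such as $s_0 W[c_0]$.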
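Written out, the energy takes the form below. This is a sketch under two assumptions not stated in this guide. First, each term enters with a minus sign, so lower energy means higher probability ($p(s, c) \propto e^{-E(s, c)}$). Second, $M$ denotes the set of interacting (spin, categorical) index pairs, which contains at least $(0, 0)$ for the $s_0 W[c_0]$ term.

The symbols are: $s_i \in \{-1, +1\}$ and $c_j \in \{0, 1, 2\}$ for the states, $b_i$ for the spin biases, $\theta_j$ for the categorical bias vectors, and $W_{ij}$ for the mixed weight vectors (so $W[c_0]$ is $W_{00}[c_0]$).

$$
E(s, c) = -\left( \sum_{i=0}^{1} b_i s_i + \sum_{j=0}^{1} \theta_j[c_j] + \sum_{(i, j) \in M} s_i \, W_{ij}[c_j] \right)
$$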
Define Node Types and Model Structure:
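The state conventions in this step can be mimicked in plain JAX. The sketch below is not THRML code. `SpinLike`, `CategoricalLike`, `NODE_SHAPE_DTYPES`, `allocate_state`, `spins_to_pm1`, and `categories_in_range` are all hypothetical names. Together they only reproduce two things described here: the `bool_`/`uint8` storage convention and a lookup from node type to `ShapeDtypeStruct`.

```python
import jax
import jax.numpy as jnp

N_CATEGORIES = 3

class SpinLike:         # hypothetical stand-in for a binary node class
    pass

class CategoricalLike:  # hypothetical stand-in for a categorical node class
    pass

# Hypothetical table in the spirit of a node-type -> ShapeDtypeStruct map:
# every node stores a scalar state of the listed dtype.
NODE_SHAPE_DTYPES = {
    SpinLike: jax.ShapeDtypeStruct((), jnp.bool_),
    CategoricalLike: jax.ShapeDtypeStruct((), jnp.uint8),
}

def allocate_state(node_type, num_nodes):
    """Zero-filled state array for `num_nodes` nodes of a single type."""
    sd = NODE_SHAPE_DTYPES[node_type]
    return jnp.zeros((num_nodes, *sd.shape), dtype=sd.dtype)

def spins_to_pm1(spins):
    """Interpret bool spin states as +1.0 (True) or -1.0 (False)."""
    return jnp.where(spins, 1.0, -1.0)

def categories_in_range(cats):
    """True when every categorical state lies in 0 .. N_CATEGORIES - 1."""
    return bool(jnp.all(cats < N_CATEGORIES))

spin_state = allocate_state(SpinLike, 2)        # [False, False], i.e. spins (-1, -1)
cat_state = allocate_state(CategoricalLike, 2)  # [0, 0]
print(spins_to_pm1(spin_state), categories_in_range(cat_state))
```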
- `N_CATEGORIES = 3`: We explicitly define the number of states our `CategoricalNode`s can take. The categorical sampler needs this value.
- `spin_nodes` and `cat_nodes`: Instances of `SpinNode` and `CategoricalNode`.
  - Each `SpinNode` represents a binary variable, stored in JAX as `True` for +1 and `False` for -1.
  - Each `CategoricalNode` represents a variable taking values from 0 to `N_CATEGORIES - 1`, stored as `jnp.uint8`.
- `node_sd_map = DEFAULT_NODE_SHAPE_DTYPES`: This dictionary tells THRML the default JAX `ShapeDtypeStruct` for each `AbstractNode` subclass. THRML needs this to allocate and manage state arrays correctly. For `SpinNode` the dtype is `bool_`; for `CategoricalNode` it is `uint8`.

Define Factors for the EBM:

Factors define the local energy contributions in the EBM. Different interaction types use different factor classes.
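To make the weight shapes concrete, the following plain-JAX sketch evaluates each factor's energy contribution and sums them, as the guide describes `FactorizedEBM` doing. It does not call THRML. The weight values are made up and the helper functions are hypothetical. It also assumes every term enters the energy with a minus sign, so lower energy means higher probability.

```python
import jax.numpy as jnp

# Made-up weights with the shapes this guide describes for each factor.
spin_bias = jnp.array([0.5, -0.3])                # SpinEBMFactor: (num_spin_nodes,)
cat_bias = jnp.array([[0.2, -0.1, 0.0],
                      [0.0, 0.4, -0.2]])          # CategoricalEBMFactor: (num_cat_nodes, N_CATEGORIES)
mixed_weights_0 = jnp.array([[1.0, -0.5, 0.3]])   # s_0 * W[c_0]: (batch_size=1, N_CATEGORIES)

def pm1(spins):
    """Map bool spin states to +1.0 / -1.0."""
    return jnp.where(spins, 1.0, -1.0)

def spin_bias_energy(spins):
    # One scalar weight per spin node.
    return -jnp.sum(spin_bias * pm1(spins))

def cat_bias_energy(cats):
    # Each node's state selects one entry of that node's weight vector.
    c = cats.astype(jnp.int32)
    return -jnp.sum(cat_bias[jnp.arange(c.shape[0]), c])

def mixed_energy(spins, cats):
    # The state of c_0 picks a weight from W; the sign of s_0 scales it.
    return -pm1(spins[0]) * mixed_weights_0[0, cats[0].astype(jnp.int32)]

def total_energy(spins, cats):
    # A factorized energy is the sum of the individual factor contributions.
    return spin_bias_energy(spins) + cat_bias_energy(cats) + mixed_energy(spins, cats)

spins = jnp.array([True, False])             # s = (+1, -1)
cats = jnp.array([0, 1], dtype=jnp.uint8)    # c = (0, 1)
print(total_energy(spins, cats))             # approximately -2.4
```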
- `SpinEBMFactor`: Used for interactions involving only `SpinNode`s. Here it defines the biases of `spin_nodes`. Each node gets one scalar weight, so the weight array has shape `(num_spin_nodes,)`.
- `CategoricalEBMFactor`: Used for interactions involving only `CategoricalNode`s. Here it defines the biases of `cat_nodes`. Each node gets a vector with one value per category, so the weight array has shape `(num_cat_nodes, N_CATEGORIES)`.
- `DiscreteEBMFactor`: The key factor for mixed interactions. It takes `spin_node_groups` and `categorical_node_groups` as separate arguments.
  - `mixed_factor_0` models the interaction $s_0 W[c_0]$. Its `spin_node_groups` is `[Block([spin_nodes[0]])]` (one spin node), and its `categorical_node_groups` is `[Block([cat_nodes[0]])]` (one categorical node). The weight tensor `mixed_weights_0` has shape `(batch_size, N_CATEGORIES)`, where `batch_size` is the number of (spin node, categorical node) pairs in this factor (1 here). The state of $c_0$ indexes into $W$, selecting the weight for that category.
- `ebm = FactorizedEBM(...)`: All factors are aggregated into a `FactorizedEBM`, which sums their energy contributions to compute the total energy of any given state.

Define Sampling Blocks and Program:
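The samplers configured in this step draw from Gibbs conditionals that follow directly from the energy. The sketch below assumes the minus-sign convention ($p \propto e^{-E}$) and uses made-up weights. It computes each conditional twice:

- once in the closed form a sampler would use: a spin with local field $h_i$ is +1 with probability $\sigma(2 h_i)$, and a categorical variable follows the softmax of its logits;
- once by comparing full energies with everything else held fixed.

All function names are hypothetical; this is not THRML's implementation.

```python
import jax
import jax.numpy as jnp

# Made-up weights (minus-sign energy convention assumed).
spin_bias = jnp.array([0.5, -0.3])                 # b_i
cat_bias = jnp.array([[0.2, -0.1, 0.0],
                      [0.0, 0.4, -0.2]])           # theta_j[k]
w_mixed = jnp.array([1.0, -0.5, 0.3])              # W[k] for the s_0 * W[c_0] term

def energy(s, c):
    """E(s, c) with s a float array in {-1, +1}^2 and c an int array in {0, 1, 2}^2."""
    return -(jnp.sum(spin_bias * s) + cat_bias[0, c[0]] + cat_bias[1, c[1]]
             + s[0] * w_mixed[c[0]])

# Spin conditional: the closed form a Bernoulli sampler would use.
def spin0_prob_closed_form(c):
    h0 = spin_bias[0] + w_mixed[c[0]]           # local field on s_0
    return jax.nn.sigmoid(2.0 * h0)             # P(s_0 = +1 | rest)

# Same probability from the energy: compare s_0 = +1 and s_0 = -1, all else fixed.
def spin0_prob_from_energy(s1, c):
    e_plus = energy(jnp.array([1.0, s1]), c)
    e_minus = energy(jnp.array([-1.0, s1]), c)
    return jax.nn.softmax(-jnp.array([e_plus, e_minus]))[0]

# Categorical conditional: logits are the negated energy terms that depend on c_0.
def cat0_logits(s0):
    return cat_bias[0] + s0 * w_mixed

# Same distribution by enumerating every category of c_0 in the full energy.
def cat0_probs_from_energy(s, c1):
    energies = jnp.array([energy(s, jnp.array([k, c1])) for k in range(3)])
    return jax.nn.softmax(-energies)

c = jnp.array([2, 1])
s = jnp.array([1.0, -1.0])
print(spin0_prob_closed_form(c), spin0_prob_from_energy(-1.0, c))
print(jax.nn.softmax(cat0_logits(1.0)), cat0_probs_from_energy(s, 1))
```

Terms that do not involve the variable being updated cancel in both comparisons. This is why each conditional depends only on a node's local field or logits, and why nodes that do not interact can be resampled together.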
- `free_blocks`: For efficient block Gibbs sampling, nodes are partitioned into `Block`s. A core rule in THRML is that all nodes within a single `Block` must be of the same type, so we create one `Block` for all `spin_nodes` and another for all `cat_nodes`. The interactions in this model only couple a spin to a categorical variable. The nodes within each block are therefore conditionally independent given the other block, which is what allows each block to be updated in a single vectorized step.
- `gibbs_spec = BlockGibbsSpec(...)`: This specification maps our blocks and node types to an internal representation used by JAX-compatible operations. Whenever a sampler requests a node's state, THRML then knows where to find it and what its shape and dtype are.
- `samplers`: A list of conditional samplers, one per block in `free_blocks`. The order of `samplers` must match the order of `free_blocks`:
  - `SpinGibbsConditional()`: Used for `Block(spin_nodes)`. It computes the local field acting on each spin and samples from a Bernoulli distribution (with `True`/`False` representing +1/-1).
  - `CategoricalGibbsConditional(N_CATEGORIES)`: Used for `Block(cat_nodes)`. It computes a vector of logits for each categorical node and samples from the categorical distribution given by their softmax. The `N_CATEGORIES` argument tells the sampler the size of the categorical state space.
- `program = FactorSamplingProgram(...)`: Combines `gibbs_spec`, `samplers`, and `ebm.factors` (which are converted into `InteractionGroup`s internally). This program encapsulates the entire sampling logic.

Initialize States and Schedule Sampling:
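A plain-JAX sketch of a random initial state with the dtypes described for each free block. The node counts match the toy model. The `jax.random.randint` output is cast to `uint8` here rather than requesting that dtype directly; this is a choice made for the sketch, not necessarily what the original example does.

```python
import jax
import jax.numpy as jnp

N_CATEGORIES = 3
num_spin_nodes, num_cat_nodes = 2, 2

key = jax.random.PRNGKey(42)
k_spin, k_cat = jax.random.split(key)

# Spin block: fair coin flips stored as bool (True = +1, False = -1).
init_spins = jax.random.bernoulli(k_spin, 0.5, shape=(num_spin_nodes,))

# Categorical block: uniform integers in [0, N_CATEGORIES), stored as uint8.
init_cats = jax.random.randint(
    k_cat, (num_cat_nodes,), minval=0, maxval=N_CATEGORIES
).astype(jnp.uint8)

# Same order as the free blocks: spin block first, categorical block second.
init_state_free = [init_spins, init_cats]
print([x.dtype for x in init_state_free])
```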
- `init_state_free`: An initial random state for the free blocks. The order of states in this list must match the order of `free_blocks`.
  - For `SpinNode`s, `jax.random.bernoulli` generates `bool_` values.
  - For `CategoricalNode`s, `jax.random.randint` generates `uint8` values in the range `[0, N_CATEGORIES)`.
- `schedule = SamplingSchedule(...)`: Sets three numbers: the warm-up steps, the samples to collect, and the steps between collected samples.
  - Warm-up steps give the chain time to approach its stationary distribution before anything is recorded.
  - Spacing out the recorded samples reduces the autocorrelation between them.

Run Sampling:
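For intuition about how the schedule drives the sampler, here is a self-contained plain-JAX block Gibbs loop for the same toy model. It is a sketch, not THRML's `sample_states`. The weights are made up and the minus-sign energy convention ($p \propto e^{-E}$) is assumed. `n_warmup`, `n_samples`, and `steps_per_sample` mimic the schedule parameters under two further assumptions: one step is one sweep over both blocks, and after warm-up the state is recorded every `steps_per_sample` sweeps. Spin and categorical samples come back as separate arrays because each block has its own dtype.

```python
import jax
import jax.numpy as jnp

# Made-up parameters for the toy model (minus-sign energy convention assumed).
spin_bias = jnp.array([0.5, -0.3])
cat_bias = jnp.array([[0.2, -0.1, 0.0],
                      [0.0, 0.4, -0.2]])
w_mixed = jnp.array([1.0, -0.5, 0.3])   # couples s_0 and c_0

@jax.jit
def sweep(key, spins, cats):
    """One block Gibbs sweep: resample the spin block, then the categorical block."""
    k_spin, k_cat = jax.random.split(key)
    # Spin block: local fields, then P(s_i = +1) = sigmoid(2 h_i); True encodes +1.
    field = spin_bias.at[0].add(w_mixed[cats[0].astype(jnp.int32)])
    spins = jax.random.bernoulli(k_spin, jax.nn.sigmoid(2.0 * field))
    # Categorical block: logits given the fresh spins, sampled from their softmax.
    s0 = jnp.where(spins[0], 1.0, -1.0)
    logits = cat_bias.at[0].add(s0 * w_mixed)
    cats = jax.random.categorical(k_cat, logits, axis=-1).astype(jnp.uint8)
    return spins, cats

def run_chain(key, spins, cats, n_warmup, n_samples, steps_per_sample):
    """Warm up, then record the state every `steps_per_sample` sweeps."""
    spin_out, cat_out = [], []
    for t in range(n_warmup + n_samples * steps_per_sample):
        key, sub = jax.random.split(key)
        spins, cats = sweep(sub, spins, cats)
        after_warmup = t + 1 - n_warmup
        if after_warmup > 0 and after_warmup % steps_per_sample == 0:
            spin_out.append(spins)
            cat_out.append(cats)
    # One array per block: bool spins and uint8 categories cannot share a dtype.
    return [jnp.stack(spin_out), jnp.stack(cat_out)]

samples = run_chain(
    jax.random.PRNGKey(0),
    jnp.array([True, True]),
    jnp.array([0, 0], dtype=jnp.uint8),
    n_warmup=100, n_samples=1000, steps_per_sample=2,
)
spin_samples, cat_samples = samples          # shapes (1000, 2) and (1000, 2)
print(jnp.mean(jnp.where(spin_samples, 1.0, -1.0), axis=0))  # estimated <s_i>
```

In this toy model $s_1$ has no mixed term, so the fraction of samples with $s_1 = +1$ should approach $\sigma(2 b_1)$. Checking an analytically known marginal like this is a cheap sanity test for a sampler.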
- `nodes_to_sample = [Block(spin_nodes), Block(cat_nodes)]`: Tells `sample_states` which node states to collect; here, every node in the model. A `Block` may only contain nodes of one type, and a JAX array has a single dtype. Spins and categorical variables are therefore requested as separate blocks. To collect only the spins, pass `[Block(spin_nodes)]`.
- `samples = sample_states(...)`: Runs the block Gibbs sampler. The output `samples` is a list of JAX arrays, one per `Block` in `nodes_to_sample`: `samples[0]` holds the spin states (`bool_`) and `samples[1]` the categorical states (`uint8`).

This example illustrates THRML's flexibility in handling heterogeneous graphical models: you can define complex interactions across node types and sample them efficiently with JAX.