This page details the core classes and functions responsible for managing blocks of nodes and orchestrating the sampling process in THRML. These components are fundamental to how THRML handles probabilistic graphical models, enabling efficient block Gibbs sampling and GPU-accelerated operations.
class Block(Generic[_Node])
A Block is the basic unit through which Gibbs sampling can operate.
Each block represents a collection of nodes that can efficiently be sampled simultaneously in a JAX-friendly SIMD manner. In THRML, this means that the nodes must all be of the same type.
Attributes:
nodes: tuple[_Node, ...]
The tuple of nodes that this block contains.
Initializes a Block instance.
Convert block-local state to the global stacked representation.
The block representation is a list where block_state[i] contains the state of spec.blocks[i], with the nodes of that block indexed along axis 0 of every leaf. The global representation is a shorter list (one entry per distinct PyTree structure) in which all blocks with the same structure are concatenated along their node axis.
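As a rough illustration of that conversion, the plain-Python sketch below groups blocks by a structure label and concatenates them along the node axis. It does not use THRML: `to_global`, the string labels, and the list-based states are hypothetical stand-ins for PyTree structures and JAX arrays.

```python
# Illustrative only: plain lists stand in for JAX arrays, and a string label
# stands in for a block's PyTree structure. None of these names are THRML API.

def to_global(block_structs, block_state):
    """Concatenate per-block states that share a structure label."""
    order = []   # distinct structures, in first-seen order
    merged = {}  # structure label -> concatenated node states
    for struct, state in zip(block_structs, block_state):
        if struct not in merged:
            order.append(struct)
            merged[struct] = []
        merged[struct].extend(state)  # append along the node axis
    return [merged[s] for s in order]

# Three blocks: two hold spin-like nodes, one holds categorical-like nodes.
block_structs = ["spin", "categorical", "spin"]
block_state = [[1, -1], [3, 0, 2], [1]]

global_state = to_global(block_structs, block_state)
print(global_state)  # [[1, -1, 1], [3, 0, 2]]
```

The block state has three entries (one per block), while the global state has two (one per distinct structure).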
class BlockSpec
This contains the necessary mappings for logging indices of states and node types.
This helps convert between block states and global states. A block state is a list of pytrees, one per block, in which every leaf has shape[0] equal to the number of nodes in that block. The global state is a flattened version of this: blocks with the same pytree type are combined, regardless of which block they come from, into a list of pytrees in which every leaf has shape[0] equal to the total number of nodes of that pytree type. For example, in an Ising model every node has the same pytree (just a scalar array), so the block state is a list of arrays, one per block, and the global state is a length-1 list containing a single array of shape (total_nodes,).
Attributes:
blocks: list[Block]
The list of blocks this spec contains.
all_block_sds: list[_PyTreeStruct]
An SD is a single _PyTreeStruct. Each node/block has only one SD associated with it, but each node can have neighbors of many types. This list holds the SD of each block, in the same order as blocks (this internal ordering is quite important for bookkeeping), and thus has length len(blocks).
global_sd_order: list[_PyTreeStruct]
The list of SDs, providing a source of truth for the global ordering.
sd_index_map: dict[_PyTreeStruct, int]
A dictionary mapping each SD to its integer position in global_sd_order. This is like calling .index on it.
node_global_location_map: dict[AbstractNode, tuple[int, int]]
A dictionary mapping a given node to a tuple. That tuple contains the global index (i.e. which element of the global list the node is in) and the node's relative position within that pytree. That is to say, you can get the state of the node via map(x[tuple[1]], global_repr[tuple[0]]).
block_to_global_slice_spec: list[list[int]]
A list over unique SDs (so its length is len(global_sd_order)), where each inner list gives the indices of the blocks that have that SD. E.g. [[0, 1], [2]] indicates that blocks[0] and blocks[1] both have SD 0 and blocks[2] has SD 1.
node_shape_dtypes: dict[Type[AbstractNode], _PyTreeStruct]
A dictionary mapping node types to hashable _PyTreeStruct.
node_shape_struct: dict[Type[AbstractNode], PyTree[jax.ShapeDtypeStruct]]
A dictionary mapping node types to pytrees of JAX shape/dtype structs (just for user access, since the keys aren't hashable, which creates issues for JAX in other areas).
Create a BlockSpec from blocks.
Based on the information passed in via node_shape_dtypes, determine the minimal global state that can be used to represent the blocks.
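The bookkeeping attributes listed for BlockSpec can be pictured with a small plain-Python sketch. This is an assumption about how such maps could be derived, not THRML's implementation: `build_maps` is hypothetical, nodes are plain strings, and a string label stands in for each block's SD.

```python
# Hypothetical sketch of the bookkeeping described above; not THRML code.
# Nodes are strings, and each block is (structure_label, tuple_of_nodes).

def build_maps(blocks):
    global_sd_order = []             # distinct structures, first-seen order
    sd_index_map = {}                # structure -> index in global_sd_order
    block_to_global_slice_spec = []  # per structure: indices of its blocks
    node_global_location_map = {}    # node -> (global index, position)
    counts = []                      # nodes placed so far, per structure

    for block_idx, (sd, nodes) in enumerate(blocks):
        if sd not in sd_index_map:
            sd_index_map[sd] = len(global_sd_order)
            global_sd_order.append(sd)
            block_to_global_slice_spec.append([])
            counts.append(0)
        g = sd_index_map[sd]
        block_to_global_slice_spec[g].append(block_idx)
        for node in nodes:
            node_global_location_map[node] = (g, counts[g])
            counts[g] += 1
    return (global_sd_order, sd_index_map,
            block_to_global_slice_spec, node_global_location_map)

blocks = [("spin", ("a", "b")), ("spin", ("c",)), ("categorical", ("x", "y"))]
order, index_map, slice_spec, location = build_maps(blocks)

print(slice_spec)     # [[0, 1], [2]]
print(location["c"])  # (0, 2)
print(location["y"])  # (1, 1)
```

Here node "c" lives in global entry 0 at position 2 because the two spin blocks are laid out back to back in the order they appear in blocks.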
Extract the states for a subset of blocks from a global state.
Locate a contiguous set of nodes inside the global state.
Allocate a zero-initialised block state.
Check that a state is what it should be given some blocks and node shape/dtypes.
class BlockGibbsSpec(BlockSpec)
A BlockGibbsSpec is a type of BlockSpec which contains additional information on free and clamped blocks.
This entity also supports SuperBlocks, which are merely groups of blocks which are sampled at the same time algorithmically, but not programmatically. That is to say, superblock = (block1, block2) means that the states input to block1 and block2 are the same, but they are not executed at the same time. This may be because they are the same color on a graph, but require vastly different sampling methods such that JAX SIMD approaches are not feasible to parallelize them.
Attributes:
free_blocks: list[Block]
The list of free blocks (in order).
sampling_order: list[list[int]]
A list of len(superblocks) lists, where each sampling_order[i] holds the indices into free_blocks to sample. Sampling is done by iterating over this order and sampling each sublist of free blocks at the same algorithmic time.
clamped_blocks: list[Block]
The list of clamped blocks.
superblocks: list[tuple[Block, ...]]
The list of superblocks.
Create a Gibbs specification from free and clamped blocks.
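One way to read the relationship between superblocks, free_blocks, and sampling_order is sketched below in plain Python. The helper `flatten_superblocks` and the string block names are hypothetical, and the flattening order is an assumption rather than a statement about THRML's internals.

```python
# Hypothetical sketch; Block objects are replaced by strings.

def flatten_superblocks(superblocks):
    """Return (free_blocks, sampling_order) for a list of superblocks."""
    free_blocks = []
    sampling_order = []
    for superblock in superblocks:
        indices = []
        for block in superblock:
            indices.append(len(free_blocks))  # position in the flat list
            free_blocks.append(block)
        sampling_order.append(indices)
    return free_blocks, sampling_order

# The first superblock groups two blocks that share a graph color but
# would need different samplers; the second holds a single block.
superblocks = [("even_spins", "even_categoricals"), ("odd_spins",)]
free_blocks, sampling_order = flatten_superblocks(superblocks)

print(free_blocks)     # ['even_spins', 'even_categoricals', 'odd_spins']
print(sampling_order)  # [[0, 1], [2]]
```

Blocks whose indices share a sublist are updated from the same input state, even though each is executed by its own sampler call.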
class BlockSamplingProgram(eqx.Module)
A PGM block-sampling program.
This class encapsulates everything that is needed to run a PGM block sampling program in THRML. per_block_interactions and per_block_interaction_active are parallel to the free blocks in gibbs_spec, and their members are passed directly to a sampler when the state of the corresponding free block is being updated during a sampling program. per_block_interaction_global_inds and per_block_interaction_global_slices are also parallel to the free blocks, and are used to slice the global state of the program to produce the state information required to update the state of each block alongside the static information contained in the interactions.
Attributes:
gibbs_spec: BlockGibbsSpec
A division of some PGM into free and clamped blocks.
samplers: list[AbstractConditionalSampler]
A sampler to use to update every free block in gibbs_spec.
per_block_interactions: list[list[PyTree]]
All the interactions that touch each free block in gibbs_spec.
per_block_interaction_active: list[list[Array]]
Indicates which interactions are real and which interactions are not part of the model and have been added to pad data structures so that they can be rectangular.
per_block_interaction_global_inds: list[list[list[int]]]
How to find the information required to update each block within the global state list.
per_block_interaction_global_slices: list[list[list[Array]]]
How to slice each array in the global state list to find the information required to update each block.
Construct a BlockSamplingProgram.
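The padding idea behind per_block_interaction_active can be sketched in plain Python: ragged neighbor lists are padded to a rectangle, and a boolean mask records which entries are real. The function `pad_with_mask` and the toy graph are hypothetical and only illustrate the general technique, not the layout THRML uses.

```python
# Hypothetical sketch: pad ragged neighbor lists to a rectangle and keep a mask.

def pad_with_mask(neighbors, pad_value=0):
    width = max(len(row) for row in neighbors)
    padded, active = [], []
    for row in neighbors:
        n_pad = width - len(row)
        padded.append(list(row) + [pad_value] * n_pad)
        active.append([True] * len(row) + [False] * n_pad)
    return padded, active

# Node 0 has one neighbor, node 1 has three, node 2 has two.
neighbors = [[1], [0, 2, 3], [1, 3]]
padded, active = pad_with_mask(neighbors)

print(padded)  # [[1, 0, 0], [0, 2, 3], [1, 3, 0]]
print(active)  # [[True, False, False], [True, True, True], [True, True, False]]

# A masked sum ignores the padded entries.
state = [1, -1, 1, 1]
fields = [sum(state[j] for j, on in zip(row, mask) if on)
          for row, mask in zip(padded, active)]
print(fields)  # [-1, 3, 0]
```

Rectangular arrays of this kind are what make a vectorized update over all nodes of a block possible, at the cost of carrying the mask alongside the data.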
Perform one iteration of sampling, visiting every block.
Samples a single block within a Gibbs sampling program based on the current states and program configurations.
It extracts neighboring states, processes required data, and applies a sampling function to generate output samples.
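For intuition about what sampling a block and visiting every block involve, here is a self-contained plain-Python sketch of block Gibbs sampling on a small open Ising chain. It is a textbook two-color scheme written for illustration; `sample_block`, `sweep`, and the parameters `coupling` and `beta` are hypothetical names and do not correspond to THRML functions.

```python
import math
import random

# Hypothetical sketch of a block Gibbs sweep on an open Ising chain.
# Even and odd sites form two blocks; sites within a block do not interact,
# so each block can be resampled from the other block's current state.

def sample_block(state, block, coupling, beta, rng):
    """Resample every site in `block` given its neighbors' states."""
    new_values = {}
    for i in block:
        field = 0.0
        if i > 0:
            field += coupling * state[i - 1]
        if i < len(state) - 1:
            field += coupling * state[i + 1]
        # P(s_i = +1 | neighbors) for energy E = -coupling * sum s_i s_j
        p_up = 1.0 / (1.0 + math.exp(-2.0 * beta * field))
        new_values[i] = 1 if rng.random() < p_up else -1
    # Write back only after the whole block has been drawn.
    for i, value in new_values.items():
        state[i] = value

def sweep(state, blocks, coupling, beta, rng):
    """One iteration: visit every block once."""
    for block in blocks:
        sample_block(state, block, coupling, beta, rng)

rng = random.Random(0)
n = 8
state = [rng.choice([-1, 1]) for _ in range(n)]
blocks = [list(range(0, n, 2)), list(range(1, n, 2))]
for _ in range(100):
    sweep(state, blocks, coupling=1.0, beta=0.5, rng=rng)
print(state)
```

Because no two sites in the same block are neighbors, the per-site loop inside `sample_block` has no data dependencies and could be replaced by a single vectorized operation.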
class SamplingSchedule
Represents a sampling schedule for a process.
Attributes:
n_warmup: int
The number of warmup steps to run before collecting samples.
n_samples: int
The number of samples to collect.
steps_per_sample: int
The number of steps to run between each sample.
Convenience wrapper to collect state information for nodes_to_sample only.
Internally builds a thrml.StateObserver, runs thrml.sample_with_observation, and returns a stacked tensor of shape (schedule.n_samples, ...).
Run the full chain and call an Observer after every recorded sample.
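The three SamplingSchedule fields suggest a loop of the following shape. This is a minimal sketch under stated assumptions: `ToySchedule` and `run_chain` are hypothetical, the `step` and `observe` callables stand in for a sampler and an observer, and the exact point at which THRML records the first sample may differ from what is shown.

```python
from dataclasses import dataclass

# Hypothetical stand-in for the schedule described above; not the THRML class.
@dataclass
class ToySchedule:
    n_warmup: int
    n_samples: int
    steps_per_sample: int

def run_chain(schedule, step, observe, state):
    """Warm up, then record a sample every `steps_per_sample` steps.

    Assumption: the first sample is recorded after warmup plus one full
    interval of `steps_per_sample` steps; THRML's convention may differ.
    """
    for _ in range(schedule.n_warmup):
        state = step(state)
    samples = []
    for _ in range(schedule.n_samples):
        for _ in range(schedule.steps_per_sample):
            state = step(state)
        samples.append(observe(state))
    return samples

# A counter stands in for the sampler state, so each sample shows the step count.
schedule = ToySchedule(n_warmup=10, n_samples=3, steps_per_sample=5)
samples = run_chain(schedule, step=lambda s: s + 1, observe=lambda s: s, state=0)
print(samples)  # [15, 20, 25]
```

Under this reading the chain runs n_warmup + n_samples * steps_per_sample steps in total, and the returned list has one entry per recorded sample.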