validate_model_outputs

torch_sim.models.interface.validate_model_outputs(model, device, dtype)

Validate the outputs of a model implementation against the interface requirements.

Runs a series of tests to ensure a model implementation correctly follows the ModelInterface contract. The validator creates sample systems, runs forward passes on them, and verifies output shapes and consistency.
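To make the idea of "verifying output shapes" concrete, here is a minimal, hypothetical sketch of that kind of check. It is not torch_sim's implementation. It assumes a model output is a dict with an "energy" entry holding one value per system and a "forces" entry holding one 3-vector per atom, and it uses plain Python lists instead of tensors so it runs without any dependencies. The helper name check_output_shapes is invented for illustration.

```python
# Hypothetical sketch, not torch_sim's implementation.
# Assumes a model output is a dict with "energy" (one value per system)
# and "forces" (one 3-vector per atom), stored as plain nested lists.


def check_output_shapes(output: dict, atoms_per_system: list) -> None:
    """Raise AssertionError if output shapes don't match the input systems."""
    n_systems = len(atoms_per_system)
    n_atoms = sum(atoms_per_system)

    energy = output["energy"]
    assert len(energy) == n_systems, (
        f"expected {n_systems} energies, got {len(energy)}"
    )

    forces = output["forces"]
    assert len(forces) == n_atoms, (
        f"expected {n_atoms} force rows, got {len(forces)}"
    )
    assert all(len(f) == 3 for f in forces), "each force must be a 3-vector"


# Usage: two systems containing 2 atoms and 1 atom respectively.
good = {
    "energy": [-1.0, -0.5],
    "forces": [[0.0, 0.0, 0.1], [0.0, 0.0, -0.1], [0.0, 0.0, 0.0]],
}
check_output_shapes(good, [2, 1])  # passes silently

bad = {"energy": [-1.5], "forces": good["forces"]}  # one energy for two systems
try:
    check_output_shapes(bad, [2, 1])
except AssertionError as err:
    print("rejected:", err)  # rejected: expected 2 energies, got 1
```

Raising AssertionError on the first mismatch mirrors the failure mode documented under Raises below.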

Parameters:
  • model (ModelInterface) – Model implementation to validate.

  • device (torch.device) – Device to run the validation tests on.

  • dtype (torch.dtype) – Data type to use for validation tensors.

Raises:

AssertionError – If the model doesn’t conform to the required interface, including issues with output shapes, types, or behavior consistency.

Return type:

None

Example:

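import torch
from torch_sim.models.interface import validate_model_outputs

# MyCustomModel stands in for your own ModelInterface implementation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MyCustomModel(device=device)

# Validate that it correctly implements the interface
validate_model_outputs(model, device=device, dtype=torch.float64)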

Notes

This validator creates small test systems (silicon and iron) for validation. It tests both single-system and batched (multi-system) inputs.
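The single-system versus batched check is the subtler of the two, because a model can return correctly shaped outputs and still mix atoms from different systems. The sketch below is a hypothetical illustration of such a consistency test, not the code torch_sim runs. It assumes a model is a callable taking atom positions, a per-atom system index, and the number of systems, and returning one energy per system. The toy model and the helper names toy_energy and check_batch_consistency are invented for this example.

```python
# Hypothetical sketch, not torch_sim's implementation: checks that a model
# gives the same per-system energy whether systems are evaluated one at a
# time or concatenated into a single batch.
import math


def toy_energy(positions, system_idx, n_systems):
    """Toy model: each system's energy is minus the sum of its atoms'
    squared coordinates. Returns one energy per system."""
    energies = [0.0] * n_systems
    for pos, s in zip(positions, system_idx):
        energies[s] -= sum(x * x for x in pos)
    return energies


def check_batch_consistency(model, systems, tol=1e-10):
    """Raise AssertionError if batched and single-system energies differ."""
    # Evaluate each system on its own (every atom belongs to system 0).
    single = [model(pos, [0] * len(pos), 1)[0] for pos in systems]
    # Concatenate all atoms and tag each one with its system index.
    batched_pos = [p for pos in systems for p in pos]
    batched_idx = [i for i, pos in enumerate(systems) for _ in pos]
    batched = model(batched_pos, batched_idx, len(systems))
    assert len(batched) == len(systems), "one energy per system expected"
    for i, (a, b) in enumerate(zip(single, batched)):
        assert math.isclose(a, b, rel_tol=tol, abs_tol=tol), (
            f"system {i}: single={a} batched={b}"
        )


systems = [
    [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]],  # system 0: two atoms
    [[0.5, 0.5, 0.5]],                   # system 1: one atom
]
check_batch_consistency(toy_energy, systems)  # passes


def buggy_energy(positions, system_idx, n_systems):
    # Bug: ignores system_idx, so every system gets the total energy.
    total = -sum(x * x for p in positions for x in p)
    return [total] * n_systems
```

Calling check_batch_consistency(buggy_energy, systems) raises AssertionError, because the batched energies (-3.75 for both systems) no longer match the single-system values (-3.0 and -0.75), even though the output has the right length.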