validate_model_outputs
- torch_sim.models.interface.validate_model_outputs(model, device, dtype)
Validate the outputs of a model implementation against the interface requirements.
Runs a series of tests to ensure a model implementation correctly follows the ModelInterface contract. The tests include creating sample systems, running forward passes, and verifying output shapes and consistency.
- Parameters:
model (ModelInterface) – Model implementation to validate.
device (device) – Device to run the validation tests on.
dtype (dtype) – Data type to use for validation tensors.
- Raises:
AssertionError – If the model doesn’t conform to the required interface, including issues with output shapes, types, or behavior consistency.
- Return type:
None
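The shape checks described above, which raise AssertionError on a mismatch, can be sketched in plain Python. This is a hypothetical, simplified stand-in and not torch_sim's implementation. The names `toy_model` and `check_output_shapes` are illustrative assumptions, as is the output contract: an `"energy"` entry with one value per system and a `"forces"` entry with three components per atom. Nested lists take the place of tensors so the sketch runs without PyTorch.

```python
from typing import Callable

# Hypothetical output contract (an assumption for this sketch):
#   "energy": one float per system
#   "forces": one [fx, fy, fz] triple per atom
Output = dict[str, list]

def toy_model(positions: list[list[float]], system_idx: list[int]) -> Output:
    """Hypothetical model: energy is the per-system sum of squared coordinates."""
    n_systems = max(system_idx) + 1
    energy = [0.0] * n_systems
    forces = []
    for pos, s in zip(positions, system_idx):
        energy[s] += sum(x * x for x in pos)
        # Force is minus the gradient of x**2, i.e. -2x per component.
        forces.append([-2.0 * x for x in pos])
    return {"energy": energy, "forces": forces}

def check_output_shapes(model: Callable[..., Output], positions, system_idx) -> None:
    """Run one forward pass and assert the outputs match the assumed contract."""
    out = model(positions, system_idx)
    n_atoms = len(positions)
    n_systems = len(set(system_idx))
    assert "energy" in out and "forces" in out, "missing output keys"
    assert len(out["energy"]) == n_systems, "expected one energy per system"
    assert len(out["forces"]) == n_atoms, "expected one force row per atom"
    assert all(len(f) == 3 for f in out["forces"]), "forces must have 3 components"

# Usage: two atoms in system 0, one atom in system 1.
positions = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.5, 0.5, 0.5]]
system_idx = [0, 0, 1]
check_output_shapes(toy_model, positions, system_idx)  # passes silently
```

A model that returned a single energy for this two-system input would fail the second assertion, mirroring the AssertionError described under Raises.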
Example:
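```python
import torch
from torch_sim.models.interface import validate_model_outputs

# Create a new model implementation (MyCustomModel stands for your own ModelInterface subclass)
device = torch.device("cuda")
model = MyCustomModel(device=device)
# Validate that it correctly implements the interface
validate_model_outputs(model, device=device, dtype=torch.float64)
```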
Notes
This validator creates small test systems (silicon and iron) for validation. It tests both single-system and batched (multi-system) processing.
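Testing single and batched processing typically means checking that evaluating several systems together in one batch gives the same per-system results as evaluating each system on its own. The sketch below illustrates that idea in plain Python under stated assumptions: `batched_energy`, `check_batch_consistency`, and the toy pairwise-distance energy are hypothetical and not torch_sim code, and the tolerance is an arbitrary choice.

```python
import math

def batched_energy(positions: list[list[float]], system_idx: list[int]) -> list[float]:
    """Hypothetical model: per-system sum of pairwise distances between atoms.

    Only atoms that share a system index interact, so batching systems
    together must not change any individual system's energy.
    """
    n_systems = max(system_idx) + 1
    energy = [0.0] * n_systems
    for i in range(len(positions)):
        for j in range(i + 1, len(positions)):
            if system_idx[i] == system_idx[j]:
                energy[system_idx[i]] += math.dist(positions[i], positions[j])
    return energy

def check_batch_consistency(systems: list[list[list[float]]], tol: float = 1e-12) -> None:
    """Assert that batched and one-at-a-time evaluation agree for every system."""
    # Evaluate each system alone (all of its atoms get system index 0).
    single = [batched_energy(pos, [0] * len(pos))[0] for pos in systems]
    # Concatenate all systems into one batch, tagging atoms with their system index.
    positions = [p for pos in systems for p in pos]
    system_idx = [s for s, pos in enumerate(systems) for _ in pos]
    batched = batched_energy(positions, system_idx)
    assert len(batched) == len(systems), "expected one energy per system"
    for s, (a, b) in enumerate(zip(single, batched)):
        assert math.isclose(a, b, rel_tol=tol, abs_tol=tol), f"system {s}: {a} != {b}"

# Usage: two toy systems standing in for the small test cells.
two_atom = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
three_atom = [[0.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 4.0]]
check_batch_consistency([two_atom, three_atom])  # passes silently
```

A model that leaked interactions across system boundaries (for example, by ignoring `system_idx` in the pair loop) would produce different batched energies and fail the final assertion.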