Usage

Basic Usage

from accmt import AcceleratorModule, Trainer, HyperParameters

class ExampleModule(AcceleratorModule):
    def __init__(self):
        # self.model is required: assign your model here.
        self.model = ...

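    def training_step(self, batch):
        x, y = batch
        # ... compute train_loss from the model's output on x and the targets y
        return train_loss  # a single scalar loss for this batch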

    def validation_step(self, key, batch):
        x, y = batch
        # ... compute val_loss (and any other metrics) for this batch
        return {
            "loss": val_loss,
            # any other metric...
        }

if __name__ == "__main__":
    module = ExampleModule()

    trainer = Trainer(
        hps_config=HyperParameters(epochs=2),
        model_path="model_folder",
    )

    train_dataset = ...
    val_dataset = ...

    trainer.fit(module, train_dataset, val_dataset)
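To make the hook contract concrete, here is a small, self-contained sketch that does not use accmt at all. It assumes only what the example above shows: training_step returns one scalar loss per batch, and validation_step returns a dict of named metrics that includes "loss". ToyModule, its "mae" metric and the evaluate helper are hypothetical names made up for illustration. This is not how accmt's Trainer is implemented, and the value passed as key is just a placeholder label.

```python
# Hypothetical sketch: NOT accmt's implementation. It only mirrors the hook
# contract from the example above: training_step returns a scalar loss and
# validation_step returns a dict of named metrics that includes "loss".
from collections import defaultdict


class ToyModule:
    """Stand-in for an AcceleratorModule subclass; the "model" is y = w * x."""

    def __init__(self, w):
        self.w = w

    def _mse(self, batch):
        # Mean squared error of the toy model on one (x, y) batch.
        x, y = batch
        return sum((self.w * xi - yi) ** 2 for xi, yi in zip(x, y)) / len(x)

    def training_step(self, batch):
        # One scalar loss per batch.
        return self._mse(batch)

    def validation_step(self, key, batch):
        # A dict of metrics; "mae" is a hypothetical extra metric.
        x, y = batch
        mae = sum(abs(self.w * xi - yi) for xi, yi in zip(x, y)) / len(x)
        return {"loss": self._mse(batch), "mae": mae}


def evaluate(module, key, batches):
    """Hypothetical helper: average each returned metric over all batches."""
    totals = defaultdict(float)
    for batch in batches:
        for name, value in module.validation_step(key, batch).items():
            totals[name] += value
    return {name: total / len(batches) for name, total in totals.items()}


if __name__ == "__main__":
    module = ToyModule(w=2.0)
    print(module.training_step(([1.0], [3.0])))  # 1.0
    val_batches = [([1.0, 2.0], [2.0, 4.0]), ([3.0], [7.0])]
    print(evaluate(module, "validation", val_batches))  # {'loss': 0.5, 'mae': 0.5}
```

The sketch takes a plain mean over batches, so the second batch (one sample) counts as much as the first (two samples). A mean weighted by batch size would report a loss of 1/3 instead. Whether accmt weights by batch size is not covered here.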

To train on multiple GPUs, save the script above as train.py and launch it with:

accmt launch train.py
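When a script is launched across several GPUs, each process usually sees only part of the training data. The sketch below illustrates one common sharding scheme: pad the index list by wrapping around so every process gets the same number of samples, then let process rank take every num_processes-th index. PyTorch's DistributedSampler behaves this way when drop_last is off. The shard_indices function is hypothetical, and this section does not say that accmt shards data this way.

```python
# Hypothetical sketch of data-parallel sharding; not accmt's actual code.
def shard_indices(num_samples, num_processes, rank):
    """Return the dataset indices that process `rank` would iterate over."""
    if num_samples <= 0:
        raise ValueError("num_samples must be positive")
    # Round the total up to a multiple of num_processes (integer ceil division).
    total = -(-num_samples // num_processes) * num_processes
    # Pad by wrapping around the dataset so every process gets the same count.
    padded = [i % num_samples for i in range(total)]
    # Interleave: process `rank` takes every num_processes-th index.
    return padded[rank:total:num_processes]


if __name__ == "__main__":
    for rank in range(4):
        print(rank, shard_indices(10, 4, rank))
    # 0 [0, 4, 8]
    # 1 [1, 5, 9]
    # 2 [2, 6, 0]
    # 3 [3, 7, 1]
```

Padding keeps the processes in lockstep because every rank runs the same number of steps. The cost is that a few samples (0 and 1 here) are seen twice per epoch.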

Advanced Usage

For more advanced usage, please refer to the API documentation.