Skip to content

PyTorch Lightning 恢复训练进度

标签
AI
AI/PyTorch
开发/Python/PyTorch
开发/Python/PyTorchLightning
开发/Python
字数
296 字
阅读时间
2 分钟

下面的代码基于这个镜像进行作业:

huggingface/transformers-pytorch-deepspeed-latest-gpu

在镜像中运行下面的命令安装 lightning

shell
pip install lightning

添加修改代码如下:

python
import os
import glob

def find_latest_checkpoint():
    # 获取所有版本目录
    version_dirs = glob.glob('lightning_logs/version_*')
    # 按版本号排序
    version_dirs.sort(key=lambda x: int(x.split('_')[-1]))
    # 获取最新版本目录
    latest_version_dir = version_dirs[-1]
    # 获取该版本目录下的所有检查点文件
    ckpt_files = glob.glob(os.path.join(latest_version_dir, 'checkpoints', '*.ckpt'))
    # 按步骤号排序
    ckpt_files.sort(key=lambda x: int(x.split('=')[-1].split('.')[0]))
    # 获取最新的检查点文件
    latest_ckpt_file = ckpt_files[-1]

    return latest_ckpt_file

def main():
    # rest of the code...

    latest_ckpt_path = find_latest_checkpoint()

    if os.path.exists(latest_ckpt_path):
        logger.info(f'latest_ckpt_path detected, resuming from {latest_ckpt_path}')
        trainer.fit(model, mnist_data, ckpt_path=latest_ckpt_path)
    else:
        trainer.fit(model, mnist_data)

if __name__ == '__main__':
    main()

之后运行下面的命令就可以自动检测最新的检查点文件并恢复训练进度啦!

shell
torchrun main.py

贡献者

文件历史

撰写