基于Huggingface Accelerate的DDP训练

一方天地 / 2024-02-08 / 原文

# -*- coding: utf-8 -*-

"""This document is a simple demo for DDP image classification."""

from typing import Callable
from argparse import ArgumentParser, Namespace

import torch
from torch.backends import cudnn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torchvision import transforms
from torchvision.datasets.cifar import CIFAR10
from torchvision.models import resnet
from tqdm import tqdm
from accelerate import Accelerator
from accelerate.utils import set_seed


def parse_args() -> Namespace:
    """Build the command-line parser and parse the process arguments.

    Returns:
        A ``Namespace`` with the attributes ``dataset``, ``epochs``,
        ``batch_size``, ``optimizer``, ``learning_rate`` and ``seed``.
    """
    # One row per option: (short flag, long flag, add_argument keyword args).
    option_table = (
        # Dataset path
        (
            "-d",
            "--dataset",
            {
                "default": "/dev/shm/dataset",
                "type": str,
                "help": "Dataset folder.",
            },
        ),
        # Number of training epochs
        (
            "-e",
            "--epochs",
            {
                "default": 248,
                "type": int,
                "help": "Number of epochs to train.",
            },
        ),
        # Mini-batch size
        (
            "-bs",
            "--batch-size",
            {
                "default": 128,
                "type": int,
                "help": "Size of mini batch.",
            },
        ),
        # Optimizer selection
        (
            "-opt",
            "--optimizer",
            {
                "default": "SGD",
                "type": str,
                "choices": ["Adam", "SGD"],
                "help": "Optimizer used to train the model.",
            },
        ),
        # Initial learning rate
        (
            "-lr",
            "--learning-rate",
            {
                "default": 2e-3,
                "type": float,
                "help": "Learning rate.",
            },
        ),
        # Random seed
        (
            "-s",
            "--seed",
            {
                "default": 0,
                "type": int,
                "help": "Random Seed.",
            },
        ),
    )

    parser = ArgumentParser()
    for short_flag, long_flag, extra in option_table:
        parser.add_argument(short_flag, long_flag, action="store", **extra)
    return parser.parse_args()


def prepare_model(num_classes: int = 1000) -> torch.nn.Module:
    """Build a ResNet-18 with a 3x3 stem convolution and optional new FC head.

    The network is created with torchvision's default ResNet-18 weights,
    then ``conv1`` is replaced by a freshly constructed 3x3 convolution.

    Args:
        num_classes: Number of output classes. When it is not 1000, the
            final ``fc`` layer is replaced by ``Linear(512, num_classes)``.

    Returns:
        The modified model. Its parameter count (in millions) is printed
        through ``accelerator.print``.
    """
    # The local main process runs this section before the other processes.
    with accelerator.local_main_process_first():
        default_weights = resnet.ResNet18_Weights.DEFAULT
        model: resnet.ResNet = resnet.resnet18(weights=default_weights)

    # For CIFAR, the first-layer 7x7 convolution of ResNet-18 is swapped for
    # a 3x3 one (the parameter count stays roughly the same).
    cifar_stem = torch.nn.Conv2d(
        3, 64, kernel_size=3, stride=1, padding=1, bias=False
    )
    model.conv1 = cifar_stem

    if num_classes != 1000:
        model.fc = torch.nn.Linear(512, num_classes)

    total_params = sum(param.numel() for param in model.parameters())
    accelerator.print(f"#params: {total_params / 1e6}M")
    return model


def prepare_dataset(folder: str):
    """Create the CIFAR-10 datasets used for training and evaluation.

    Args:
        folder: Root directory passed to ``CIFAR10``.

    Returns:
        A ``(train_data, train_eval_data, test_data)`` tuple: the training
        split with augmentation, the training split with evaluation-only
        preprocessing, and the test split with evaluation-only
        preprocessing.
    """
    normalize_transform = transforms.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
    )
    # Evaluation preprocessing: tensor conversion + normalisation only.
    plain_pipeline = transforms.Compose(
        [transforms.ToTensor(), normalize_transform]
    )
    # Training preprocessing: augmentation first, then the same two steps.
    augmented_pipeline = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(0.25),
            transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
            transforms.ToTensor(),
            normalize_transform,
        ]
    )

    # The local main process goes first and is the only one that passes
    # download=True; the other processes construct with download=False.
    with accelerator.local_main_process_first():
        may_download = accelerator.is_local_main_process
        train_data = CIFAR10(
            folder,
            train=True,
            transform=augmented_pipeline,
            download=may_download,
        )
        test_data = CIFAR10(
            folder,
            train=False,
            transform=plain_pipeline,
            download=may_download,
        )

    # Built after the guarded section, without a download request.
    train_eval_data = CIFAR10(folder, train=True, transform=plain_pipeline)
    return train_data, train_eval_data, test_data


def get_data_loader(
    batch_size: int,
    train_data: Dataset,
    train_eval_data: Dataset,
    test_data: Dataset,
) -> tuple[DataLoader, DataLoader, DataLoader]:
    """Wrap the three datasets in DataLoaders prepared by the accelerator.

    Args:
        batch_size: Batch size of the training loader; both evaluation
            loaders use twice this value.
        train_data: Dataset iterated (shuffled) during training.
        train_eval_data: Training-split dataset used for evaluation.
        test_data: Test-split dataset used for evaluation.

    Returns:
        ``(train_loader, train_eval_loader, test_loader)`` as returned by
        ``accelerator.prepare``.
    """
    # Loader worker processes are only used in single-process runs.
    worker_count = 2 if accelerator.num_processes == 1 else 0
    shared_options = {"pin_memory": True, "num_workers": worker_count}
    eval_batch_size = batch_size * 2

    train_loader: DataLoader = DataLoader(
        train_data, batch_size=batch_size, shuffle=True, **shared_options
    )
    train_eval_loader: DataLoader = DataLoader(
        train_eval_data,
        batch_size=eval_batch_size,
        shuffle=False,
        **shared_options,
    )
    test_loader: DataLoader = DataLoader(
        test_data,
        batch_size=eval_batch_size,
        shuffle=False,
        **shared_options,
    )
    return accelerator.prepare(train_loader, train_eval_loader, test_loader)


@torch.enable_grad()
def train_epoch(
    model: torch.nn.Module,
    loss_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
) -> None:
    """Run one training pass over ``dataloader``.

    Args:
        model: Network to train; switched to training mode here.
        loss_func: Callable mapping ``(output, targets)`` to a loss tensor.
        dataloader: Iterable yielding ``(source, targets)`` batches.
        optimizer: Optimizer stepped once per batch.
    """
    model.train()
    # The progress bar is only shown on the local main process.
    hide_bar = not accelerator.is_local_main_process
    for inputs, labels in tqdm(dataloader, disable=hide_bar):
        optimizer.zero_grad()
        batch_loss = loss_func(model(inputs), labels)
        # Backward pass goes through the accelerator rather than loss.backward().
        accelerator.backward(batch_loss)
        optimizer.step()


@torch.no_grad()
def eval_epoch(
    model: torch.nn.Module,
    loss_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    dataloader: DataLoader,
) -> tuple[float, float]:
    """Measure the loss and accuracy of ``model`` on ``dataloader``.

    Must be called by every process: it performs collective operations.

    Args:
        model: Network to evaluate; switched to evaluation mode here.
        loss_func: Callable mapping ``(output, targets)`` to a loss tensor.
        dataloader: Iterable yielding ``(source, targets)`` batches.

    Returns:
        ``(mean_loss, accuracy)``. ``accuracy`` is computed over the
        predictions gathered from all processes. ``mean_loss`` is the
        per-batch mean loss, averaged across processes when more than one
        process is running.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no batches.
    """
    model.eval()
    dataloader_with_bar = tqdm(
        dataloader, disable=(not accelerator.is_local_main_process)
    )
    correct_sum, loss_sum, cnt_samples = 0, 0.0, 0
    for source, targets in dataloader_with_bar:
        output: torch.Tensor = model(source)
        loss = loss_func(output, targets)

        prediction: torch.Tensor = accelerator.gather_for_metrics(
            output.argmax(dim=1) == targets
        )  # type: ignore
        correct_sum += prediction.sum().item()
        loss_sum += loss.item()
        cnt_samples += len(prediction)

    mean_loss = loss_sum / len(dataloader)
    if accelerator.num_processes > 1:
        # The loss above only covers the batches this process saw, while the
        # accuracy is global. Average the per-process means so that both
        # metrics describe the same data and agree on every process.
        local_mean = torch.tensor(mean_loss, device=accelerator.device)
        mean_loss = accelerator.reduce(local_mean, reduction="mean").item()
    return mean_loss, correct_sum / cnt_samples


def main(args: Namespace):
    """Train and evaluate the classifier, logging metrics to ``log.csv``.

    Epoch 0 performs evaluation only (no training step), so the first log
    row describes the model before any optimisation. After every epoch the
    local main process appends a CSV row, saves the model to
    ``./weights/last`` and, when validation accuracy improves, to
    ``./weights/best``.

    Args:
        args: Parsed command-line options. Reads ``seed``, ``dataset``,
            ``batch_size``, ``optimizer``, ``learning_rate`` and ``epochs``.
    """
    set_seed(args.seed)
    model = prepare_model(10)
    train_data, train_eval_data, test_data = prepare_dataset(args.dataset)
    train_loader, train_eval_loader, test_loader = get_data_loader(
        args.batch_size, train_data, train_eval_data, test_data
    )

    # Honour the --optimizer flag. The previous condition was inverted, so
    # "SGD" built Adam and "Adam" built SGD.
    optimizer: torch.optim.Optimizer
    if args.optimizer == "SGD":
        optimizer = torch.optim.SGD(
            model.parameters(),
            args.learning_rate,
            momentum=0.90,
            weight_decay=2e-2,
        )
    else:
        optimizer = torch.optim.Adam(model.parameters(), args.learning_rate)

    loss_func = torch.nn.CrossEntropyLoss(label_smoothing=0.05)
    scheduler: CosineAnnealingWarmRestarts = CosineAnnealingWarmRestarts(
        optimizer, 8, 2
    )
    model, optimizer, loss_func, scheduler = accelerator.prepare(
        model, optimizer, loss_func, scheduler
    )

    best_acc = 0

    # Only the process that writes the log opens it. Opening with "wt" on
    # every process would let other ranks truncate the file underneath the
    # writer.
    log_file = (
        open("log.csv", "wt", encoding="utf-8")
        if accelerator.is_local_main_process
        else None
    )
    try:
        if log_file is not None:
            print(
                "epoch,train_loss,train_acc,val_loss,val_acc,learning_rate",
                file=log_file,
            )
            log_file.flush()

        for epoch in range(args.epochs + 1):
            accelerator.print(
                f"Epoch {epoch}/{args.epochs}",
                f"(lr={optimizer.param_groups[-1]['lr']}):",
            )

            # Train the model (skipped for the epoch-0 baseline evaluation).
            if epoch != 0:
                train_epoch(model, loss_func, train_loader, optimizer)

            accelerator.wait_for_everyone()

            # Evaluate on the training split and on the test split.
            train_loss, train_acc = eval_epoch(
                model, loss_func, train_eval_loader
            )
            val_loss, val_acc = eval_epoch(model, loss_func, test_loader)
            accelerator.print(
                f"[ Training ] Acc: {train_acc * 100:.2f}% "
                f"Loss: {train_loss:.4f}"
            )

            # Log the metrics and save the latest / best weights.
            accelerator.wait_for_everyone()
            if log_file is not None:
                print(
                    epoch,
                    train_loss,
                    train_acc,
                    val_loss,
                    val_acc,
                    optimizer.param_groups[-1]["lr"],
                    sep=",",
                    file=log_file,
                )
                log_file.flush()
                accelerator.save_model(model, "./weights/last")
                if val_acc > best_acc:
                    best_acc = val_acc
                    accelerator.save_model(model, "./weights/best")
            accelerator.wait_for_everyone()

            accelerator.print(
                f"[Validation] Acc: {val_acc * 100:.2f}%",
                f"Loss: {val_loss:.4f}",
                f"Best: {best_acc * 100:.2f}%",
            )

            # The scheduler is stepped once per trained epoch.
            # NOTE(review): the scheduler went through accelerator.prepare;
            # confirm how many underlying steps one call performs when
            # several processes are running.
            if epoch != 0:
                scheduler.step()
    finally:
        # Close the log even when training aborts with an exception.
        if log_file is not None:
            log_file.close()


if __name__ == "__main__":
    # Enable cuDNN benchmark mode before any model is built.
    cudnn.benchmark = True
    # Module-level global: the functions above read `accelerator` directly,
    # so it must exist before main() runs.
    accelerator = Accelerator()
    cli_args = parse_args()
    main(cli_args)