在深度学习训练中,我们常用很多超参数。Python 提供三个工具帮助管理:

  1. argparse:从命令行传参
  2. YAML:写配置文件
  3. easydict:让字典支持点号访问

结合它们,就能灵活地管理训练参数,代码更清晰、可复现。


# 1. 命令行是怎么调用 Python 脚本的?

假设你有一个脚本 train.py

# train.py
print("Hello")

在命令行里:

python train.py
  • python:调用 Python 解释器
  • train.py:要运行的脚本

如果你想给脚本传参:

python train.py --batch_size 64 --lr 0.01

这里 --batch_size 64--lr 0.01 就是命令行参数。Python 默认不会解析它们,需要用 argparse


# 2. argparse 详解

import argparse

parser = argparse.ArgumentParser(description="训练脚本示例")
  • 创建解析器
  • description 会显示在 --help

# 2.1 添加参数

parser.add_argument("--batch_size", type=int, default=32, help="训练批大小")
parser.add_argument("--lr", type=float, default=0.001, help="学习率")
parser.add_argument("--use_gpu", action="store_true", help="是否使用 GPU")
args = parser.parse_args()

# 参数讲解

参数 说明
--batch_size 参数名(命令行用 --batch_size 64
type=int 自动转换类型,如果传入非 int 会报错
default=32 如果命令行没传这个参数,就用默认值 32
help="..." python train.py -h 中显示帮助信息
action="store_true" 布尔开关,不传为 False,传了为 True

# 运行示例

python train.py --batch_size 64 --use_gpu
print(args.batch_size)  # 64
print(args.use_gpu)     # True
  • 不写 --use_gpuFalse
  • 写了 --use_gpuTrue

# 2.2 其他常用 action

  • store_true / store_false:布尔开关
  • append:可以多次传参,把值收集成列表
  • count:统计参数出现次数(例如 -v 打印更详细日志)
parser.add_argument("-v", "--verbose", action="count", default=0, help="增加日志详细级别")

运行:

python train.py -vvv
print(args.verbose)  # 3

# 2.3 位置参数(必填)

parser.add_argument("dataset_path", help="数据集路径")
  • 命令行必须写
python train.py ./data
print(args.dataset_path)  # ./data

# 2.4 限定选项 choices

parser.add_argument("--optimizer", choices=["sgd", "adam"], default="adam", help="优化器类型")
  • 只能是 "sgd""adam"
  • 传其他值 argparse 会报错

# 2.5 多值参数 nargs

parser.add_argument("--layers", type=int, nargs="+", help="每层神经元数量")

运行:

python train.py --layers 64 128 256
print(args.layers)  # [64, 128, 256]
  • nargs='+':一个或多个
  • nargs='*':零个或多个

# 2.6 查看帮助

python train.py -h

输出:

usage: train.py [-h] [--batch_size BATCH_SIZE] [--lr LR] [--use_gpu] ...

训练脚本示例

optional arguments:
  -h, --help            show this help message and exit
  --batch_size BATCH_SIZE
                                                训练批大小 (default: 32)
    --lr LR               学习率 (default: 0.001)
  --use_gpu             是否使用 GPU

# 3. YAML 讲解

YAML 是一种人类可读的配置文件格式,常用于存储训练参数。

# 3.1 示例 config.yaml

train:
  batch_size: 64
  learning_rate: 0.001
  num_epochs: 10

model:
  name: "resnet18"
  num_classes: 100

# 3.2 Python 读取

import yaml

with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

print(config["train"]["batch_size"])  # 64

推荐使用 safe_load,防止执行不安全代码


# 4. easydict 讲解

easydict 可以把字典转换成支持点号访问的对象,这样访问嵌套参数更清晰:

from easydict import EasyDict as edict

cfg_dict = {
    "train": {"batch_size": 64, "learning_rate": 0.001, "num_epochs": 10},
    "model": {"name": "resnet18", "num_classes": 100}
}

cfg = edict(cfg_dict)

print(cfg.train.batch_size)      # 64
print(cfg.model.name)            # resnet18

# 也可以像普通字典一样访问
print(cfg["train"]["learning_rate"])  # 0.001

# 5. argparse + YAML + easydict 结合使用

思路:

  1. argparse 解析命令行参数
  2. 读取 YAML 配置
  3. 转换为 EasyDict
  4. 用命令行参数覆盖 YAML 配置

完整示例:

import argparse
import yaml
from easydict import EasyDict as edict

def get_config():
    parser = argparse.ArgumentParser(description="训练脚本")
    parser.add_argument("--config", type=str, default="config.yaml", help="配置文件路径")
    parser.add_argument("--batch_size", type=int, help="覆盖 YAML 的批大小")
    parser.add_argument("--learning_rate", type=float, help="覆盖 YAML 的学习率")
    parser.add_argument("--num_epochs", type=int, help="覆盖 YAML 的训练轮数")
    parser.add_argument("--use_gpu", action="store_true", help="是否使用 GPU")
    args = parser.parse_args()

    # 读取 YAML 并转换为 EasyDict
    with open(args.config, "r") as f:
        cfg = edict(yaml.safe_load(f))

    # 命令行覆盖
    if args.batch_size is not None:
        cfg.train.batch_size = args.batch_size
    if args.learning_rate is not None:
        cfg.train.learning_rate = args.learning_rate
    if args.num_epochs is not None:
        cfg.train.num_epochs = args.num_epochs
    cfg.train.use_gpu = args.use_gpu

    return cfg

def main():
    cfg = get_config()
    print("最终训练配置:")
    print(cfg)

    # 使用配置
    print(f"训练 {cfg.model.name},批大小 {cfg.train.batch_size},学习率 {cfg.train.learning_rate}")

if __name__ == "__main__":
    main()

# 5.1 使用方法

  1. 默认 YAML 配置:
python train.py
  1. 指定 YAML 文件:
python train.py --config configs/exp1.yaml
  1. 命令行覆盖部分参数:
python train.py --batch_size 128 --learning_rate 0.0005 --use_gpu

访问参数时可以直接用点号:

cfg.train.batch_size
cfg.model.name

训练循环里使用:

for epoch in range(cfg.train.num_epochs):
    for x, y in dataloader(batch_size=cfg.train.batch_size):
        ...

# 6. 总结

  • argparse:灵活修改、临时覆盖参数
  • YAML:记录实验参数、便于复现
  • easydict:让配置访问更直观、简洁
  • 结合使用:YAML 保存默认参数,argparse 命令行覆盖,easydict 优化访问