在深度学习训练中,我们常用很多超参数。Python 提供三个工具帮助管理:
- argparse:从命令行传参
- YAML:写配置文件
- easydict:让字典支持点号访问
结合它们,就能灵活地管理训练参数,代码更清晰、可复现。
# 1. 命令行是怎么调用 Python 脚本的?
假设你有一个脚本 train.py:
# train.py
print("Hello")
在命令行里:
python train.py
python:调用 Python 解释器train.py:要运行的脚本
如果你想给脚本传参:
python train.py --batch_size 64 --lr 0.01
这里 --batch_size 64、--lr 0.01 就是命令行参数。Python 默认不会解析它们,需要用 argparse。
# 2. argparse 详解
import argparse
parser = argparse.ArgumentParser(description="训练脚本示例")
- 创建解析器
description会显示在--help中
# 2.1 添加参数
parser.add_argument("--batch_size", type=int, default=32, help="训练批大小")
parser.add_argument("--lr", type=float, default=0.001, help="学习率")
parser.add_argument("--use_gpu", action="store_true", help="是否使用 GPU")
args = parser.parse_args()
# 参数讲解
| 参数 | 说明 |
|---|---|
--batch_size |
参数名(命令行用 --batch_size 64) |
type=int |
自动转换类型,如果传入非 int 会报错 |
default=32 |
如果命令行没传这个参数,就用默认值 32 |
help="..." |
在 python train.py -h 中显示帮助信息 |
action="store_true" |
布尔开关,不传为 False,传了为 True |
# 运行示例
python train.py --batch_size 64 --use_gpu
print(args.batch_size) # 64
print(args.use_gpu) # True
- 不写
--use_gpu:False - 写了
--use_gpu:True
# 2.2 其他常用 action
store_true/store_false:布尔开关append:可以多次传参,把值收集成列表count:统计参数出现次数(例如-v打印更详细日志)
parser.add_argument("-v", "--verbose", action="count", default=0, help="增加日志详细级别")
运行:
python train.py -vvv
print(args.verbose) # 3
# 2.3 位置参数(必填)
parser.add_argument("dataset_path", help="数据集路径")
- 命令行必须写
python train.py ./data
print(args.dataset_path) # ./data
# 2.4 限定选项 choices
parser.add_argument("--optimizer", choices=["sgd", "adam"], default="adam", help="优化器类型")
- 只能是
"sgd"或"adam" - 传其他值 argparse 会报错
# 2.5 多值参数 nargs
parser.add_argument("--layers", type=int, nargs="+", help="每层神经元数量")
运行:
python train.py --layers 64 128 256
print(args.layers) # [64, 128, 256]
nargs='+':一个或多个nargs='*':零个或多个
# 2.6 查看帮助
python train.py -h
输出:
usage: train.py [-h] [--batch_size BATCH_SIZE] [--lr LR] [--use_gpu] ...
训练脚本示例
optional arguments:
-h, --help show this help message and exit
--batch_size BATCH_SIZE
训练批大小 (default: 32)
--lr LR 学习率 (default: 0.001)
--use_gpu 是否使用 GPU
# 3. YAML 讲解
YAML 是一种人类可读的配置文件格式,常用于存储训练参数。
# 3.1 示例 config.yaml
train:
batch_size: 64
learning_rate: 0.001
num_epochs: 10
model:
name: "resnet18"
num_classes: 100
# 3.2 Python 读取
import yaml
with open("config.yaml", "r") as f:
config = yaml.safe_load(f)
print(config["train"]["batch_size"]) # 64
推荐使用
safe_load,防止执行不安全代码
# 4. easydict 讲解
easydict 可以把字典转换成支持点号访问的对象,这样访问嵌套参数更清晰:
from easydict import EasyDict as edict
cfg_dict = {
"train": {"batch_size": 64, "learning_rate": 0.001, "num_epochs": 10},
"model": {"name": "resnet18", "num_classes": 100}
}
cfg = edict(cfg_dict)
print(cfg.train.batch_size) # 64
print(cfg.model.name) # resnet18
# 也可以像普通字典一样访问
print(cfg["train"]["learning_rate"]) # 0.001
# 5. argparse + YAML + easydict 结合使用
思路:
- argparse 解析命令行参数
- 读取 YAML 配置
- 转换为
EasyDict - 用命令行参数覆盖 YAML 配置
完整示例:
import argparse
import yaml
from easydict import EasyDict as edict
def get_config():
parser = argparse.ArgumentParser(description="训练脚本")
parser.add_argument("--config", type=str, default="config.yaml", help="配置文件路径")
parser.add_argument("--batch_size", type=int, help="覆盖 YAML 的批大小")
parser.add_argument("--learning_rate", type=float, help="覆盖 YAML 的学习率")
parser.add_argument("--num_epochs", type=int, help="覆盖 YAML 的训练轮数")
parser.add_argument("--use_gpu", action="store_true", help="是否使用 GPU")
args = parser.parse_args()
# 读取 YAML 并转换为 EasyDict
with open(args.config, "r") as f:
cfg = edict(yaml.safe_load(f))
# 命令行覆盖
if args.batch_size is not None:
cfg.train.batch_size = args.batch_size
if args.learning_rate is not None:
cfg.train.learning_rate = args.learning_rate
if args.num_epochs is not None:
cfg.train.num_epochs = args.num_epochs
cfg.train.use_gpu = args.use_gpu
return cfg
def main():
cfg = get_config()
print("最终训练配置:")
print(cfg)
# 使用配置
print(f"训练 {cfg.model.name},批大小 {cfg.train.batch_size},学习率 {cfg.train.learning_rate}")
if __name__ == "__main__":
main()
# 5.1 使用方法
- 默认 YAML 配置:
python train.py
- 指定 YAML 文件:
python train.py --config configs/exp1.yaml
- 命令行覆盖部分参数:
python train.py --batch_size 128 --learning_rate 0.0005 --use_gpu
访问参数时可以直接用点号:
cfg.train.batch_size
cfg.model.name
训练循环里使用:
for epoch in range(cfg.train.num_epochs):
for x, y in dataloader(batch_size=cfg.train.batch_size):
...
# 6. 总结
- argparse:灵活修改、临时覆盖参数
- YAML:记录实验参数、便于复现
- easydict:让配置访问更直观、简洁
- 结合使用:YAML 保存默认参数,argparse 命令行覆盖,easydict 优化访问
