# 1. Hello World 入门示例

import gradio as gr

# 定义一个最简单的函数
def greet(name: str):
    return f"Hello, {name}!"

# 使用 Interface 快速构建应用
demo = gr.Interface(
    fn=greet,  # 绑定函数
    inputs=gr.Textbox(label="请输入名字"),  # 输入组件:单行文本框
    outputs=gr.Textbox(label="问候语"),    # 输出组件:单行文本框
    title="Gradio Hello World"
)

demo.launch(share=True)  # 启动服务,并生成公网临时链接

# 组件解释:

  • Textbox:输入或输出文本,支持 lines 参数控制多行。
  • Interface:最基础的封装方式,适合单函数输入输出。

# 2. 加载 PyTorch .pth 模型并结合 Gradio

这里我们以 ResNet18 二分类模型(猫 vs 狗) 为例:

import torch
import torchvision.transforms as T
from torchvision.models import resnet18
from PIL import Image
import gradio as gr

# 1. 加载模型结构
model = resnet18(pretrained=False)
model.fc = torch.nn.Linear(model.fc.in_features, 2)  # 修改输出层为二分类

# 2. 加载训练好的权重(model.pth)
model.load_state_dict(torch.load("model.pth", map_location="cpu"))
model.eval()  # 切换到推理模式

# 3. 定义图像预处理(与训练时一致)
transform = T.Compose([
    T.Resize((224, 224)),  # 缩放图片
    T.ToTensor(),          # 转为 Tensor
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225])  # 标准化
])

# 4. 推理函数
def predict(img: Image.Image):
    x = transform(img).unsqueeze(0)  # 扩展 batch 维度 [1,3,224,224]
    with torch.no_grad():
        logits = model(x)  # 前向传播
        probs = torch.softmax(logits, dim=1).squeeze().tolist()  # 转换为概率
    return {"cat": probs[0], "dog": probs[1]}  # 输出字典
    

# 5. Gradio 封装
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="上传图片"),  # 输入组件:图像
    outputs=gr.Label(num_top_classes=2, label="预测结果"),  # 输出组件:标签(带概率条)
    title="猫狗分类"
)

demo.launch()

# 组件解释:

  • Image:支持上传或摄像头输入,type="pil" 表示输出为 PIL Image。
  • Label:显示分类概率,自动渲染柱状图。

# 3. 使用 Blocks 构建多步骤界面

with gr.Blocks() as demo:
    with gr.Row():  # 横向布局
        img = gr.Image(type="pil", label="上传图片")
        btn = gr.Button("开始分类")  # 按钮
    out = gr.Label(num_top_classes=2, label="预测结果")

    # 点击按钮时执行预测
    btn.click(fn=predict, inputs=img, outputs=out)

demo.launch()

# 组件解释:

  • Blocks:更灵活的 UI 构建方式。
  • Row / Column:布局容器。
  • Button:可绑定点击事件。

# 4. 扩展功能:显示 Top-K + 原图

def predict_with_img(img: Image.Image):
    x = transform(img).unsqueeze(0)
    with torch.no_grad():
        logits = model(x)
        probs = torch.softmax(logits, dim=1).squeeze()
    classes = ["cat", "dog"]
    result = {cls: float(probs[i]) for i, cls in enumerate(classes)}
    return result, img  # 返回两个结果

with gr.Blocks() as demo:
    with gr.Row():
        img = gr.Image(type="pil", label="上传图片")
        btn = gr.Button("开始分类")
    with gr.Row():
        out_probs = gr.Label(num_top_classes=2, label="预测概率")
        out_img = gr.Image(label="原始图片")

    btn.click(fn=predict_with_img, inputs=img, outputs=[out_probs, out_img])

demo.launch()

# 解释:

  • Gradio 函数可以返回多个输出,并绑定到多个组件。
  • Label 用来显示概率条,Image 再显示原图,增强可解释性。

# 5. 常用组件一览

  • Textbox:输入文本
  • Image:上传或拍照输入图像
  • Label:展示分类结果与概率
  • Button:触发事件
  • Chatbot:用于构建对话界面(常见于 LLM Demo)
  • File:上传或下载文件
  • Slider / Dropdown / Checkbox:用于选择参数

# 6. 队列与并发控制

demo.queue(concurrency_count=2).launch()
  • queue():开启请求队列,避免推理时间过长导致前端超时。
  • concurrency_count:同时处理的请求数。

# 7. 自定义样式 / 鉴权 / 安全

with gr.Blocks(css="body {background-color: #f0f8ff;}") as demo:
    img = gr.Image(type="pil")
    out = gr.Label()
    img.change(predict, inputs=img, outputs=out)

demo.launch(auth=("admin", "1234"), allowed_paths=["/safe_dir"])
  • css=:注入自定义 CSS。
  • auth=:开启简单的用户名/密码鉴权。
  • allowed_paths=:限制访问目录,避免文件泄露。

# 8. 最终完整应用(整合版)

import gradio as gr
import torch
import torchvision.transforms as T
from torchvision.models import resnet18
from PIL import Image

# 模型加载
model = resnet18(pretrained=False)
model.fc = torch.nn.Linear(model.fc.in_features, 2)
model.load_state_dict(torch.load("model.pth", map_location="cpu"))
model.eval()

# 预处理
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225])
])

# 推理函数
def predict(img: Image.Image):
    x = transform(img).unsqueeze(0)
    with torch.no_grad():
        logits = model(x)
        probs = torch.softmax(logits, dim=1).squeeze()
    classes = ["cat", "dog"]
    return {cls: float(probs[i]) for i, cls in enumerate(classes)}

# 构建应用
with gr.Blocks(css="body {font-family: Arial; background:#fafafa;}") as demo:
    with gr.Row():
        img = gr.Image(type="pil", label="上传图片")
        btn = gr.Button("开始分类")
    with gr.Row():
        out = gr.Label(num_top_classes=2, label="预测结果")
    btn.click(predict, inputs=img, outputs=out)

# 开启队列 + 鉴权
demo.queue(concurrency_count=2).launch(auth=("user", "pass"))