# 1. Hello World 入门示例
import gradio as gr
# 定义一个最简单的函数
def greet(name: str):
return f"Hello, {name}!"
# 使用 Interface 快速构建应用
demo = gr.Interface(
fn=greet, # 绑定函数
inputs=gr.Textbox(label="请输入名字"), # 输入组件:单行文本框
outputs=gr.Textbox(label="问候语"), # 输出组件:单行文本框
title="Gradio Hello World"
)
demo.launch(share=True) # 启动服务,并生成公网临时链接
# 组件解释:
- Textbox:输入或输出文本,支持
lines参数控制多行。 - Interface:最基础的封装方式,适合单函数输入输出。
# 2. 加载 PyTorch .pth 模型并结合 Gradio
这里我们以 ResNet18 二分类模型(猫 vs 狗) 为例:
import torch
import torchvision.transforms as T
from torchvision.models import resnet18
from PIL import Image
import gradio as gr
# 1. 加载模型结构
model = resnet18(pretrained=False)
model.fc = torch.nn.Linear(model.fc.in_features, 2) # 修改输出层为二分类
# 2. 加载训练好的权重(model.pth)
model.load_state_dict(torch.load("model.pth", map_location="cpu"))
model.eval() # 切换到推理模式
# 3. 定义图像预处理(与训练时一致)
transform = T.Compose([
T.Resize((224, 224)), # 缩放图片
T.ToTensor(), # 转为 Tensor
T.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]) # 标准化
])
# 4. 推理函数
def predict(img: Image.Image):
x = transform(img).unsqueeze(0) # 扩展 batch 维度 [1,3,224,224]
with torch.no_grad():
logits = model(x) # 前向传播
probs = torch.softmax(logits, dim=1).squeeze().tolist() # 转换为概率
return {"cat": probs[0], "dog": probs[1]} # 输出字典
# 5. Gradio 封装
demo = gr.Interface(
fn=predict,
inputs=gr.Image(type="pil", label="上传图片"), # 输入组件:图像
outputs=gr.Label(num_top_classes=2, label="预测结果"), # 输出组件:标签(带概率条)
title="猫狗分类"
)
demo.launch()
# 组件解释:
- Image:支持上传或摄像头输入,
type="pil"表示输出为 PIL Image。 - Label:显示分类概率,自动渲染柱状图。
# 3. 使用 Blocks 构建多步骤界面
with gr.Blocks() as demo:
with gr.Row(): # 横向布局
img = gr.Image(type="pil", label="上传图片")
btn = gr.Button("开始分类") # 按钮
out = gr.Label(num_top_classes=2, label="预测结果")
# 点击按钮时执行预测
btn.click(fn=predict, inputs=img, outputs=out)
demo.launch()
# 组件解释:
- Blocks:更灵活的 UI 构建方式。
- Row / Column:布局容器。
- Button:可绑定点击事件。
# 4. 扩展功能:显示 Top-K + 原图
def predict_with_img(img: Image.Image):
x = transform(img).unsqueeze(0)
with torch.no_grad():
logits = model(x)
probs = torch.softmax(logits, dim=1).squeeze()
classes = ["cat", "dog"]
result = {cls: float(probs[i]) for i, cls in enumerate(classes)}
return result, img # 返回两个结果
with gr.Blocks() as demo:
with gr.Row():
img = gr.Image(type="pil", label="上传图片")
btn = gr.Button("开始分类")
with gr.Row():
out_probs = gr.Label(num_top_classes=2, label="预测概率")
out_img = gr.Image(label="原始图片")
btn.click(fn=predict_with_img, inputs=img, outputs=[out_probs, out_img])
demo.launch()
# 解释:
- Gradio 函数可以返回多个输出,并绑定到多个组件。
Label用来显示概率条,Image再显示原图,增强可解释性。
# 5. 常用组件一览
- Textbox:输入文本
- Image:上传或拍照输入图像
- Label:展示分类结果与概率
- Button:触发事件
- Chatbot:用于构建对话界面(常见于 LLM Demo)
- File:上传或下载文件
- Slider / Dropdown / Checkbox:用于选择参数
# 6. 队列与并发控制
demo.queue(concurrency_count=2).launch()
- queue():开启请求队列,避免推理时间过长导致前端超时。
- concurrency_count:同时处理的请求数。
# 7. 自定义样式 / 鉴权 / 安全
with gr.Blocks(css="body {background-color: #f0f8ff;}") as demo:
img = gr.Image(type="pil")
out = gr.Label()
img.change(predict, inputs=img, outputs=out)
demo.launch(auth=("admin", "1234"), allowed_paths=["/safe_dir"])
- css=:注入自定义 CSS。
- auth=:开启简单的用户名/密码鉴权。
- allowed_paths=:限制访问目录,避免文件泄露。
# 8. 最终完整应用(整合版)
import gradio as gr
import torch
import torchvision.transforms as T
from torchvision.models import resnet18
from PIL import Image
# 模型加载
model = resnet18(pretrained=False)
model.fc = torch.nn.Linear(model.fc.in_features, 2)
model.load_state_dict(torch.load("model.pth", map_location="cpu"))
model.eval()
# 预处理
transform = T.Compose([
T.Resize((224, 224)),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# 推理函数
def predict(img: Image.Image):
x = transform(img).unsqueeze(0)
with torch.no_grad():
logits = model(x)
probs = torch.softmax(logits, dim=1).squeeze()
classes = ["cat", "dog"]
return {cls: float(probs[i]) for i, cls in enumerate(classes)}
# 构建应用
with gr.Blocks(css="body {font-family: Arial; background:#fafafa;}") as demo:
with gr.Row():
img = gr.Image(type="pil", label="上传图片")
btn = gr.Button("开始分类")
with gr.Row():
out = gr.Label(num_top_classes=2, label="预测结果")
btn.click(predict, inputs=img, outputs=out)
# 开启队列 + 鉴权
demo.queue(concurrency_count=2).launch(auth=("user", "pass"))
