import torch
def predict(model_path, input_data):
# 1. 初始化模型结构 (默认是 train 模式)
model = MyDeepLearningModel()
# 2. 加载权重 (这只恢复了参数,没改模式)
model.load_state_dict(torch.load(model_path))
# ==========================================
# 关键步骤来了!
# ==========================================
# 3. 切换到评估模式 (固定住 Dropout 和 BatchNorm)
model.eval()
# 4. 关闭梯度计算 (省显存,加速)
with torch.no_grad():
# 数据预处理...
output = model(input_data)
# 获取结果...
return output