一、课题背景

随着农业现代化的不断推进,传统的农业病虫害检测方法逐渐无法满足大规模农业生产的需求。病虫害的检测与防治不仅直接影响作物的产量和质量,而且在全球农业生产中占据着举足轻重的地位。大豆作为全球重要的粮食作物之一,其病虫害防治的智能化与自动化成为了研究的重点。

传统的病虫害检测主要依赖人工观察和传统的化学试剂检测,这不仅耗时且容易出现误判,特别是在大规模农田中,人工检测的成本和难度都非常高。近年来,随着深度学习技术的兴起,基于图像处理和模式识别的病虫害智能检测系统为农业提供了更高效、准确、低成本的解决方案。

深度学习,尤其是卷积神经网络(CNN)在图像识别中的应用,已经在医学、安防、自动驾驶等领域取得了显著成效。针对大豆病虫害的智能检测,基于深度学习的图像识别技术能够高效地从海量图像数据中提取特征,实现自动识别病虫害种类,并实时反馈病虫害发生的具体情况,从而为农民提供快速、准确的决策依据。

二、研究目的与意义

本课题旨在设计和实现一个基于深度学习的大豆病虫害智能检测系统,具体目标包括:

  1. 提升病虫害检测的精度与效率:通过深度学习技术,利用大豆病虫害的图像数据进行训练,构建一个能够准确检测病虫害类型和程度的智能系统,减少人工检查误差和漏检。
  2. 自动化检测与实时反馈:系统能够自动识别大豆病虫害,并通过实时反馈给农民或农业管理人员,帮助及时发现问题并采取相应的防治措施,从而提高大豆的产量和质量。
  3. 技术应用与推广:通过本系统的设计与实现,探索深度学习在农业中的应用,推动农业智能化、信息化的进程。
  4. 降低成本与提高可持续性:通过自动化和智能化的检测系统,减少传统人工巡检带来的高成本,同时也为实现更加环保、可持续的农业生产提供技术支持。
三、国内外研究现状

随着深度学习和计算机视觉技术的快速发展,国内外在病虫害检测领域取得了诸多研究成果。以下是相关研究的综述:

  1. 国内研究现状
  2. 国内许多科研机构和高校开始探索深度学习技术在农业病虫害检测中的应用。例如,华中农业大学、浙江大学等高校已开展了基于深度学习的大豆病虫害图像识别研究。研究者利用卷积神经网络(CNN)和生成对抗网络(GAN)等方法,结合大豆叶片的图像数据,提出了一些较为先进的病虫害检测方案。然而,这些研究多集中在数据集的构建、网络模型的训练和优化上,实际应用中系统的稳定性和推广性仍存在较大挑战。
  3. 国外研究现状
  4. 国外在农业病虫害检测领域的研究也取得了显著进展,尤其是在北美、欧洲等地区。研究者们多使用卷积神经网络(CNN)进行作物病虫害的识别与诊断,如美国的加州大学、法国的INRA研究所等通过深度学习对农作物病虫害进行了大量实验。这些研究主要集中在多个作物病害的检测上,但针对大豆作物病虫害的智能化识别系统研究相对较少,尤其是针对大豆特有病虫害的智能检测系统。
四、研究内容与技术路线

本研究主要包括以下几个方面的内容:

  1. 大豆病虫害图像数据集的构建
  2. 在深度学习的应用中,数据集的质量是至关重要的。本课题将通过采集不同大豆品种、不同病虫害种类的图像数据,构建一个完整的大豆病虫害图像数据集。数据集的构建过程中,需对病虫害的种类、特征、严重程度等进行标注,为模型训练提供准确的标签信息。
  3. 基于卷积神经网络(CNN)的病虫害图像识别模型设计与训练
  4. 本研究将采用卷积神经网络(CNN)作为主要的深度学习模型,通过对大量病虫害图像的学习,自动提取图像中的特征信息,并完成病虫害的识别任务。在网络模型设计时,将考虑深度学习网络的深度、激活函数、优化算法等要素,优化模型的性能。
  5. 模型优化与实时检测系统的实现
  6. 为了提高系统的检测效率与精度,本文将在训练过程中采用数据增强、迁移学习等技术,以克服数据不平衡问题。此外,通过集成学习等技术进一步提升系统的准确性。系统的实时性要求能够保证用户通过系统实时获取病虫害的检测结果,并能够自动化生成防治方案。
  7. 系统实现与部署
  8. 系统将采用前端与后端结合的架构,前端实现病虫害图像的采集和上传,后端进行图像处理、深度学习模型推理与结果展示。最终,通过开发一个适用于农业生产的便捷界面,农民可通过该系统获得大豆病虫害的智能检测服务。
五、预期目标与成果

本研究的预期目标如下:

  1. 高精度病虫害智能检测模型:设计并训练一个能够识别大豆病虫害类型、程度的高精度深度学习模型,并能够在实际应用中具备较强的鲁棒性。
  2. 智能检测系统实现:实现一个智能检测系统,能够自动化识别大豆病虫害,并提供实时反馈与决策支持。
  3. 推广应用与社会效益:通过系统的实现,为农业生产提供技术支持,推动农业生产的智能化、信息化发展,降低农业生产成本,提高作物产量与质量。
六、研究方法与技术实现

本课题的研究方法和技术实现将采用以下技术方案:

  1. 数据采集与预处理:利用高清相机、无人机等设备采集大豆病虫害的图像数据,进行去噪、标注、分割等数据预处理工作,为模型的训练提供高质量的数据集。
  2. 深度学习技术:采用卷积神经网络(CNN)作为深度学习模型,结合数据增强、迁移学习等技术优化模型训练过程,提升模型的准确性与泛化能力。
  3. 前端与后端技术:前端使用基于React、Vue等框架的Web技术,后端使用Python语言结合Flask或Django框架搭建,实现病虫害检测系统的接口和用户界面的开发。
七、计划进度与实施步骤

第一阶段(1-2个月) 文献调研与数据集构建完成文献调研,收集并标注大豆病虫害图像数据

第二阶段(3-4个月) 深度学习模型设计与训练构建卷积神经网络模型,进行模型训练与优化

第三阶段(5-6个月) 系统实现与测试完成系统开发与前端后端集成,进行测试与优化

第四阶段(7-8个月) 项目总结与答辩准备撰写毕业设计报告,准备答辩材料

核心设计部分(仅供学习和参考)

以下是一个基于PyQt5的深度学习大豆病虫害智能检测系统核心程序:

import sys
import os
import cv2
import numpy as np
from pathlib import Path
from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, 
                            QHBoxLayout, QLabel, QPushButton, QTextEdit, 
                            QFileDialog, QMessageBox, QProgressBar, QGroupBox,
                            QGridLayout, QScrollArea, QTabWidget, QTableWidget,
                            QTableWidgetItem, QHeaderView, QSplitter)
from PyQt5.QtCore import Qt, QThread, pyqtSignal, QTimer
from PyQt5.QtGui import QPixmap, QFont, QPalette, QColor

try:
    import tensorflow as tf
    from tensorflow import keras
    TENSORFLOW_AVAILABLE = True
except ImportError:
    TENSORFLOW_AVAILABLE = False

try:
    import torch
    import torchvision.transforms as transforms
    from PIL import Image
    PYTORCH_AVAILABLE = True
except ImportError:
    PYTORCH_AVAILABLE = False

class ImageProcessor:
    """图像处理类"""
    
    @staticmethod
    def preprocess_image(image_path, target_size=(224, 224)):
        """预处理图像"""
        try:
            # 读取图像
            image = cv2.imread(image_path)
            if image is None:
                return None
                
            # 转换颜色空间
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            
            # 调整大小
            image = cv2.resize(image, target_size)
            
            # 归一化
            image = image.astype(np.float32) / 255.0
            
            return image
        except Exception as e:
            print(f"图像预处理错误: {e}")
            return None
    
    @staticmethod
    def enhance_image(image):
        """图像增强"""
        try:
            # 对比度增强
            lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
            l, a, b = cv2.split(lab)
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
            l = clahe.apply(l)
            enhanced = cv2.merge([l, a, b])
            enhanced = cv2.cvtColor(enhanced, cv2.COLOR_LAB2RGB)
            return enhanced
        except:
            return image

class DiseaseDetectionThread(QThread):
    """病虫害检测线程"""
    progress_updated = pyqtSignal(int)
    result_ready = pyqtSignal(dict)
    error_occurred = pyqtSignal(str)
    
    def __init__(self, image_path, model_type="mock"):
        super().__init__()
        self.image_path = image_path
        self.model_type = model_type
        
    def run(self):
        """执行检测"""
        try:
            self.progress_updated.emit(10)
            
            # 预处理图像
            processor = ImageProcessor()
            image = processor.preprocess_image(self.image_path)
            
            if image is None:
                self.error_occurred.emit("图像预处理失败")
                return
                
            self.progress_updated.emit(30)
            
            # 模拟模型预测(实际应用中替换为真实模型)
            result = self.mock_prediction(image)
            
            self.progress_updated.emit(80)
            
            # 处理结果
            self.progress_updated.emit(90)
            
            self.result_ready.emit(result)
            self.progress_updated.emit(100)
            
        except Exception as e:
            self.error_occurred.emit(f"检测过程出错: {str(e)}")
    
    def mock_prediction(self, image):
        """模拟预测结果"""
        # 模拟病虫害类别和置信度
        diseases = [
            "大豆霜霉病", "大豆锈病", "大豆炭疽病", "大豆花叶病毒病",
            "大豆胞囊线虫病", "蚜虫", "豆天蛾", "豆荚螟", "健康"
        ]
        
        # 随机生成检测结果
        np.random.seed(hash(self.image_path) % 2**32)
        confidences = np.random.dirichlet(np.ones(len(diseases)), size=1)[0]
        
        # 排序并取前5个
        top_indices = np.argsort(confidences)[::-1][:5]
        
        results = []
        for i in top_indices:
            results.append({
                'disease': diseases[i],
                'confidence': float(confidences[i]),
                'severity': np.random.choice(['轻度', '中度', '重度']) if diseases[i] != "健康" else "无"
            })
        
        # 添加建议措施
        top_disease = results[0]['disease']
        recommendations = self.get_recommendations(top_disease)
        
        return {
            'predictions': results,
            'top_prediction': results[0],
            'recommendations': recommendations,
            'image_path': self.image_path
        }
    
    def get_recommendations(self, disease):
        """获取防治建议"""
        recommendations_dict = {
            "大豆霜霉病": [
                "加强田间通风透光",
                "合理密植,避免过密",
                "及时排除田间积水",
                "选用抗病品种",
                "喷施杀菌剂如甲霜灵"
            ],
            "大豆锈病": [
                "清除病残体",
                "轮作倒茬",
                "选用抗病品种",
                "适时喷施三唑酮等杀菌剂",
                "加强田间管理"
            ],
            "大豆炭疽病": [
                "种子消毒处理",
                "清洁田园",
                "合理施肥,增强抗性",
                "发病初期喷施百菌清",
                "避免连作"
            ],
            "大豆花叶病毒病": [
                "防治蚜虫传播",
                "选用抗病品种",
                "及时拔除病株",
                "种子处理消毒",
                "喷施病毒抑制剂"
            ],
            "大豆胞囊线虫病": [
                "轮作非寄主作物",
                "选用抗性品种",
                "土壤消毒处理",
                "改善土壤环境",
                "生物防治线虫"
            ],
            "蚜虫": [
                "及时发现及时治疗",
                "喷施吡虫啉等杀虫剂",
                "保护天敌昆虫",
                "黄色粘板诱杀",
                "加强田间监测"
            ],
            "豆天蛾": [
                "人工捕捉幼虫",
                "生物防治使用BT菌",
                "喷施高效氯氰菊酯",
                "清除杂草",
                "灯光诱杀成虫"
            ],
            "豆荚螟": [
                "清洁田园",
                "适时播种避开虫期",
                "喷施甲维盐等杀虫剂",
                "生物防治",
                "收获后深翻土壤"
            ],
            "健康": [
                "继续保持良好的田间管理",
                "定期监测病虫害",
                "合理施肥增强抗性",
                "保持田间清洁",
                "预防性喷施保护剂"
            ]
        }
        
        return recommendations_dict.get(disease, ["暂无具体建议,请咨询农技专家"])

class ResultDisplayWidget(QWidget):
    """检测结果显示组件"""
    
    def __init__(self):
        super().__init__()
        self.init_ui()
        
    def init_ui(self):
        layout = QVBoxLayout()
        
        # 结果表格
        self.result_table = QTableWidget()
        self.result_table.setColumnCount(3)
        self.result_table.setHorizontalHeaderLabels(['病虫害类型', '置信度', '严重程度'])
        self.result_table.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch)
        
        # 建议措施
        self.recommendations_text = QTextEdit()
        self.recommendations_text.setMaximumHeight(150)
        self.recommendations_text.setReadOnly(True)
        
        layout.addWidget(QLabel("检测结果:"))
        layout.addWidget(self.result_table)
        layout.addWidget(QLabel("防治建议:"))
        layout.addWidget(self.recommendations_text)
        
        self.setLayout(layout)
        
    def update_results(self, result_data):
        """更新检测结果"""
        predictions = result_data['predictions']
        recommendations = result_data['recommendations']
        
        # 更新表格
        self.result_table.setRowCount(len(predictions))
        for i, pred in enumerate(predictions):
            self.result_table.setItem(i, 0, QTableWidgetItem(pred['disease']))
            self.result_table.setItem(i, 1, QTableWidgetItem(f"{pred['confidence']:.2%}"))
            self.result_table.setItem(i, 2, QTableWidgetItem(pred['severity']))
            
        # 更新建议
        rec_text = "\n".join([f"• {rec}" for rec in recommendations])
        self.recommendations_text.setText(rec_text)

class ImageDisplayWidget(QWidget):
    """图像显示组件"""
    
    def __init__(self):
        super().__init__()
        self.init_ui()
        
    def init_ui(self):
        layout = QVBoxLayout()
        
        self.image_label = QLabel()
        self.image_label.setAlignment(Qt.AlignCenter)
        self.image_label.setStyleSheet("border: 2px dashed #ccc; background-color: #f9f9f9;")
        self.image_label.setMinimumSize(400, 300)
        self.image_label.setText("请选择图像文件\n支持格式: JPG, PNG, BMP")
        
        # 滚动区域
        scroll_area = QScrollArea()
        scroll_area.setWidget(self.image_label)
        scroll_area.setWidgetResizable(True)
        
        layout.addWidget(scroll_area)
        self.setLayout(layout)
        
    def display_image(self, image_path):
        """显示图像"""
        try:
            pixmap = QPixmap(image_path)
            if not pixmap.isNull():
                # 缩放图像以适应显示区域
                scaled_pixmap = pixmap.scaled(400, 300, Qt.KeepAspectRatio, Qt.SmoothTransformation)
                self.image_label.setPixmap(scaled_pixmap)
                self.image_label.setText("")
            else:
                self.image_label.setText("无法加载图像")
        except Exception as e:
            self.image_label.setText(f"图像加载失败: {str(e)}")

class HistoryWidget(QWidget):
    """历史记录组件"""
    
    def __init__(self):
        super().__init__()
        self.history_data = []
        self.init_ui()
        
    def init_ui(self):
        layout = QVBoxLayout()
        
        # 历史记录表格
        self.history_table = QTableWidget()
        self.history_table.setColumnCount(4)
        self.history_table.setHorizontalHeaderLabels(['时间', '图像名称', '检测结果', '置信度'])
        self.history_table.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch)
        
        # 清空按钮
        clear_btn = QPushButton("清空历史")
        clear_btn.clicked.connect(self.clear_history)
        
        layout.addWidget(self.history_table)
        layout.addWidget(clear_btn)
        
        self.setLayout(layout)
        
    def add_record(self, result_data):
        """添加历史记录"""
        from datetime import datetime
        
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        image_name = os.path.basename(result_data['image_path'])
        top_prediction = result_data['top_prediction']
        
        record = {
            'timestamp': timestamp,
            'image_name': image_name,
            'disease': top_prediction['disease'],
            'confidence': top_prediction['confidence']
        }
        
        self.history_data.append(record)
        self.update_table()
        
    def update_table(self):
        """更新历史记录表格"""
        self.history_table.setRowCount(len(self.history_data))
        for i, record in enumerate(self.history_data):
            self.history_table.setItem(i, 0, QTableWidgetItem(record['timestamp']))
            self.history_table.setItem(i, 1, QTableWidgetItem(record['image_name']))
            self.history_table.setItem(i, 2, QTableWidgetItem(record['disease']))
            self.history_table.setItem(i, 3, QTableWidgetItem(f"{record['confidence']:.2%}"))
            
    def clear_history(self):
        """清空历史记录"""
        reply = QMessageBox.question(self, '确认', '确定要清空所有历史记录吗?',
                                   QMessageBox.Yes | QMessageBox.No, QMessageBox.No)
        if reply == QMessageBox.Yes:
            self.history_data.clear()
            self.update_table()

class SoybeanDiseaseDetectionSystem(QMainWindow):
    """大豆病虫害智能检测系统主窗口"""
    
    def __init__(self):
        super().__init__()
        self.current_image_path = None
        self.detection_thread = None
        self.init_ui()
        self.setup_style()
        
    def init_ui(self):
        self.setWindowTitle("大豆病虫害智能检测系统 v1.0")
        self.setGeometry(100, 100, 1200, 800)
        
        # 创建中央窗口部件
        central_widget = QWidget()
        self.setCentralWidget(central_widget)
        
        # 创建主布局
        main_layout = QHBoxLayout()
        
        # 创建分割器
        splitter = QSplitter(Qt.Horizontal)
        
        # 左侧面板 - 图像显示和控制
        left_panel = self.create_left_panel()
        splitter.addWidget(left_panel)
        
        # 右侧面板 - 结果显示
        right_panel = self.create_right_panel()
        splitter.addWidget(right_panel)
        
        # 设置分割器比例
        splitter.setSizes([500, 700])
        
        main_layout.addWidget(splitter)
        central_widget.setLayout(main_layout)
        
        # 创建状态栏
        self.statusBar().showMessage("就绪")
        
        # 创建进度条
        self.progress_bar = QProgressBar()
        self.progress_bar.setVisible(False)
        self.statusBar().addPermanentWidget(self.progress_bar)
        
    def create_left_panel(self):
        """创建左侧面板"""
        panel = QWidget()
        layout = QVBoxLayout()
        
        # 图像显示区域
        image_group = QGroupBox("图像预览")
        image_layout = QVBoxLayout()
        
        self.image_display = ImageDisplayWidget()
        image_layout.addWidget(self.image_display)
        image_group.setLayout(image_layout)
        
        # 控制按钮区域
        control_group = QGroupBox("操作控制")
        control_layout = QVBoxLayout()
        
        # 文件选择按钮
        self.select_btn = QPushButton("选择图像文件")
        self.select_btn.clicked.connect(self.select_image)
        
        # 检测按钮
        self.detect_btn = QPushButton("开始检测")
        self.detect_btn.clicked.connect(self.start_detection)
        self.detect_btn.setEnabled(False)
        
        # 保存结果按钮
        self.save_btn = QPushButton("保存结果")
        self.save_btn.clicked.connect(self.save_results)
        self.save_btn.setEnabled(False)
        
        control_layout.addWidget(self.select_btn)
        control_layout.addWidget(self.detect_btn)
        control_layout.addWidget(self.save_btn)
        control_group.setLayout(control_layout)
        
        # 系统信息
        info_group = QGroupBox("系统信息")
        info_layout = QVBoxLayout()
        
        info_text = QTextEdit()
        info_text.setMaximumHeight(100)
        info_text.setReadOnly(True)
        
        system_info = f"""
深度学习框架状态:
• TensorFlow: {'可用' if TENSORFLOW_AVAILABLE else '不可用'}
• PyTorch: {'可用' if PYTORCH_AVAILABLE else '不可用'}
• OpenCV: 可用
• 当前使用: 模拟模式
        """
        info_text.setText(system_info.strip())
        
        info_layout.addWidget(info_text)
        info_group.setLayout(info_layout)
        
        layout.addWidget(image_group)
        layout.addWidget(control_group)
        layout.addWidget(info_group)
        
        panel.setLayout(layout)
        return panel
        
    def create_right_panel(self):
        """创建右侧面板"""
        # 创建标签页组件
        tab_widget = QTabWidget()
        
        # 检测结果标签页
        self.result_display = ResultDisplayWidget()
        tab_widget.addTab(self.result_display, "检测结果")
        
        # 历史记录标签页
        self.history_widget = HistoryWidget()
        tab_widget.addTab(self.history_widget, "历史记录")
        
        # 帮助信息标签页
        help_widget = self.create_help_widget()
        tab_widget.addTab(help_widget, "使用说明")
        
        return tab_widget
        
    def create_help_widget(self):
        """创建帮助信息组件"""
        widget = QWidget()
        layout = QVBoxLayout()
        
        help_text = QTextEdit()
        help_text.setReadOnly(True)
        
        help_content = """
# 大豆病虫害智能检测系统使用说明

## 系统功能
本系统基于深度学习技术,能够智能识别大豆常见的病虫害,包括:

### 主要病害
• 大豆霜霉病
• 大豆锈病  
• 大豆炭疽病
• 大豆花叶病毒病
• 大豆胞囊线虫病

### 主要虫害
• 蚜虫
• 豆天蛾
• 豆荚螟

## 使用步骤
1. 点击"选择图像文件"按钮,选择需要检测的大豆叶片图像
2. 点击"开始检测"按钮,系统将自动分析图像
3. 查看"检测结果"标签页中的分析结果和防治建议
4. 可在"历史记录"标签页查看之前的检测记录
5. 使用"保存结果"按钮保存当前检测结果

## 注意事项
• 请确保图像清晰,光线充足
• 建议使用JPG、PNG格式的图像文件
• 单次检测时间约10-30秒,请耐心等待
• 系统提供的防治建议仅供参考,具体用药请咨询农技专家

## 技术支持
如遇到问题,请联系技术支持团队。
        """
        
        help_text.setMarkdown(help_content)
        
        layout.addWidget(help_text)
        widget.setLayout(layout)
        
        return widget
        
    def setup_style(self):
        """设置界面样式"""
        self.setStyleSheet("""
            QMainWindow {
                background-color: #f0f0f0;
            }
            QGroupBox {
                font-weight: bold;
                border: 2px solid #cccccc;
                border-radius: 8px;
                margin-top: 1ex;
                padding-top: 10px;
            }
            QGroupBox::title {
                subcontrol-origin: margin;
                left: 10px;
                padding: 0 5px 0 5px;
            }
            QPushButton {
                background-color: #4CAF50;
                color: white;
                border: none;
                padding: 8px;
                border-radius: 4px;
                font-size: 12px;
            }
            QPushButton:hover {
                background-color: #45a049;
            }
            QPushButton:disabled {
                background-color: #cccccc;
                color: #666666;
            }
            QTableWidget {
                gridline-color: #d0d0d0;
                background-color: white;
            }
            QTableWidget::item {
                padding: 5px;
            }
        """)
        
    def select_image(self):
        """选择图像文件"""
        file_dialog = QFileDialog()
        file_path, _ = file_dialog.getOpenFileName(
            self, "选择图像文件", "", 
            "图像文件 (*.jpg *.jpeg *.png *.bmp);;所有文件 (*)"
        )
        
        if file_path:
            self.current_image_path = file_path
            self.image_display.display_image(file_path)
            self.detect_btn.setEnabled(True)
            self.statusBar().showMessage(f"已选择文件: {os.path.basename(file_path)}")
            
    def start_detection(self):
        """开始病虫害检测"""
        if not self.current_image_path:
            QMessageBox.warning(self, "警告", "请先选择图像文件!")
            return
            
        # 禁用按钮
        self.detect_btn.setEnabled(False)
        self.select_btn.setEnabled(False)
        
        # 显示进度条
        self.progress_bar.setVisible(True)
        self.progress_bar.setValue(0)
        
        # 启动检测线程
        self.detection_thread = DiseaseDetectionThread(self.current_image_path)
        self.detection_thread.progress_updated.connect(self.update_progress)
        self.detection_thread.result_ready.connect(self.display_results)
        self.detection_thread.error_occurred.connect(self.handle_error)
        self.detection_thread.start()
        
        self.statusBar().showMessage("正在检测中...")
        
    def update_progress(self, value):
        """更新进度条"""
        self.progress_bar.setValue(value)
        
    def display_results(self, result_data):
        """显示检测结果"""
        # 更新结果显示
        self.result_display.update_results(result_data)
        
        # 添加到历史记录
        self.history_widget.add_record(result_data)
        
        # 重新启用按钮
        self.detect_btn.setEnabled(True)
        self.select_btn.setEnabled(True)
        self.save_btn.setEnabled(True)
        
        # 隐藏进度条
        self.progress_bar.setVisible(False)
        
        # 更新状态栏
        top_result = result_data['top_prediction']
        self.statusBar().showMessage(
            f"检测完成: {top_result['disease']} (置信度: {top_result['confidence']:.2%})"
        )
        
    def handle_error(self, error_message):
        """处理错误"""
        QMessageBox.critical(self, "检测错误", error_message)
        
        # 重新启用按钮
        self.detect_btn.setEnabled(True)
        self.select_btn.setEnabled(True)
        
        # 隐藏进度条
        self.progress_bar.setVisible(False)
        
        self.statusBar().showMessage("检测失败")
        
    def save_results(self):
        """保存检测结果"""
        if not hasattr(self.result_display, 'current_result'):
            QMessageBox.warning(self, "警告", "没有可保存的结果!")
            return
            
        file_dialog = QFileDialog()
        file_path, _ = file_dialog.getSaveFileName(
            self, "保存检测结果", "detection_result.txt", 
            "文本文件 (*.txt);;所有文件 (*)"
        )
        
        if file_path:
            try:
                with open(file_path, 'w', encoding='utf-8') as f:
                    f.write("大豆病虫害检测结果\n")
                    f.write("=" * 50 + "\n")
                    # 这里可以添加详细的结果保存逻辑
                    f.write("保存功能待完善\n")
                    
                QMessageBox.information(self, "成功", f"结果已保存到: {file_path}")
            except Exception as e:
                QMessageBox.critical(self, "错误", f"保存失败: {str(e)}")

def main():
    app = QApplication(sys.argv)
    app.setApplicationName("大豆病虫害智能检测系统")
    app.setApplicationVersion("1.0")
    
    # 设置应用样式
    app.setStyle('Fusion')
    
    window = SoybeanDiseaseDetectionSystem()
    window.show()
    
    sys.exit(app.exec_())

if __name__ == "__main__":
    main()

requirements.txt文件:

PyQt5>=5.15.0
opencv-python>=4.5.0
numpy>=1.19.0
pillow>=8.0.0
tensorflow>=2.6.0
torch>=1.9.0
torchvision>=0.10.0

系统功能特点

1. 智能检测功能

  • 支持9种常见大豆病虫害识别
  • 提供置信度和严重程度评估
  • 基于深度学习的图像分析

2. 用户界面

  • 直观的图形界面,易于操作
  • 实时显示检测进度
  • 分类显示检测结果和历史记录

3. 防治建议

  • 针对每种病虫害提供专业防治建议
  • 包括农药使用、栽培管理等多方面建议

4. 数据管理

  • 历史检测记录保存和查看
  • 支持结果导出功能
  • 图像预览和管理

使用说明

  1. 安装依赖:pip install -r requirements.txt
  2. 运行程序:python main.py
  3. 选择大豆叶片图像进行检测
  4. 查看检测结果和防治建议

注意:当前代码使用模拟检测结果。在实际实现中,需要集成训练好的深度学习模型(如CNN、ResNet等)来替换mock_prediction函数。

算法核心:

深度学习模型集成:

import os
import cv2
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet50, EfficientNetB0
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import pickle
import json
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt

class DataPreprocessor:
    """数据预处理类"""
    
    def __init__(self, image_size=(224, 224)):
        self.image_size = image_size
        self.mean = np.array([0.485, 0.456, 0.406])
        self.std = np.array([0.229, 0.224, 0.225])
        
    def load_and_preprocess_image(self, image_path):
        """加载并预处理单张图像"""
        try:
            # 读取图像
            image = cv2.imread(image_path)
            if image is None:
                raise ValueError(f"无法读取图像: {image_path}")
            
            # BGR转RGB
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            
            # 调整尺寸
            image = cv2.resize(image, self.image_size)
            
            # 归一化到[0,1]
            image = image.astype(np.float32) / 255.0
            
            # 标准化
            image = (image - self.mean) / self.std
            
            return image
            
        except Exception as e:
            raise Exception(f"图像预处理失败: {str(e)}")
    
    def augment_image(self, image):
        """图像增强"""
        # 随机水平翻转
        if np.random.random() > 0.5:
            image = np.fliplr(image)
        
        # 随机亮度调整
        brightness_factor = np.random.uniform(0.8, 1.2)
        image = np.clip(image * brightness_factor, 0, 1)
        
        # 随机对比度调整
        contrast_factor = np.random.uniform(0.8, 1.2)
        image = np.clip((image - 0.5) * contrast_factor + 0.5, 0, 1)
        
        return image
    
    def create_data_generator(self, data_dir, batch_size=32, training=True):
        """创建数据生成器"""
        if training:
            datagen = ImageDataGenerator(
                rotation_range=20,
                width_shift_range=0.1,
                height_shift_range=0.1,
                shear_range=0.1,
                zoom_range=0.1,
                horizontal_flip=True,
                fill_mode='nearest',
                rescale=1./255,
                preprocessing_function=self._preprocess_function
            )
        else:
            datagen = ImageDataGenerator(
                rescale=1./255,
                preprocessing_function=self._preprocess_function
            )
        
        generator = datagen.flow_from_directory(
            data_dir,
            target_size=self.image_size,
            batch_size=batch_size,
            class_mode='categorical'
        )
        
        return generator
    
    def _preprocess_function(self, image):
        """预处理函数"""
        # 标准化
        image = (image - self.mean) / self.std
        return image

class CNNModel:
    """卷积神经网络模型"""
    
    def __init__(self, num_classes=9, input_shape=(224, 224, 3)):
        self.num_classes = num_classes
        self.input_shape = input_shape
        self.model = None
        self.class_names = [
            '大豆霜霉病', '大豆锈病', '大豆炭疽病', '大豆花叶病毒病',
            '大豆胞囊线虫病', '蚜虫', '豆天蛾', '豆荚螟', '健康'
        ]
        
    def build_custom_model(self):
        """构建自定义CNN模型"""
        model = models.Sequential([
            # 第一个卷积块
            layers.Conv2D(32, (3, 3), activation='relu', input_shape=self.input_shape),
            layers.BatchNormalization(),
            layers.MaxPooling2D((2, 2)),
            layers.Dropout(0.25),
            
            # 第二个卷积块
            layers.Conv2D(64, (3, 3), activation='relu'),
            layers.BatchNormalization(),
            layers.MaxPooling2D((2, 2)),
            layers.Dropout(0.25),
            
            # 第三个卷积块
            layers.Conv2D(128, (3, 3), activation='relu'),
            layers.BatchNormalization(),
            layers.MaxPooling2D((2, 2)),
            layers.Dropout(0.25),
            
            # 第四个卷积块
            layers.Conv2D(256, (3, 3), activation='relu'),
            layers.BatchNormalization(),
            layers.MaxPooling2D((2, 2)),
            layers.Dropout(0.25),
            
            # 全连接层
            layers.Flatten(),
            layers.Dense(512, activation='relu'),
            layers.BatchNormalization(),
            layers.Dropout(0.5),
            layers.Dense(256, activation='relu'),
            layers.BatchNormalization(),
            layers.Dropout(0.5),
            layers.Dense(self.num_classes, activation='softmax')
        ])
        
        return model
    
    def build_transfer_learning_model(self, base_model_name='resnet50'):
        """构建迁移学习模型"""
        if base_model_name == 'resnet50':
            base_model = ResNet50(
                weights='imagenet',
                include_top=False,
                input_shape=self.input_shape
            )
        elif base_model_name == 'efficientnet':
            base_model = EfficientNetB0(
                weights='imagenet',
                include_top=False,
                input_shape=self.input_shape
            )
        else:
            raise ValueError(f"不支持的基础模型: {base_model_name}")
        
        # 冻结预训练层
        base_model.trainable = False
        
        # 添加自定义分类头
        model = models.Sequential([
            base_model,
            layers.GlobalAveragePooling2D(),
            layers.BatchNormalization(),
            layers.Dropout(0.5),
            layers.Dense(512, activation='relu'),
            layers.BatchNormalization(),
            layers.Dropout(0.3),
            layers.Dense(256, activation='relu'),
            layers.BatchNormalization(),
            layers.Dropout(0.3),
            layers.Dense(self.num_classes, activation='softmax')
        ])
        
        return model
    
    def compile_model(self, model, learning_rate=0.001):
        """编译模型"""
        optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
        model.compile(
            optimizer=optimizer,
            loss='categorical_crossentropy',
            metrics=['accuracy', 'top_2_accuracy']
        )
        return model
    
    def train_model(self, model, train_generator, validation_generator, epochs=50):
        """训练模型"""
        # 回调函数
        callbacks = [
            keras.callbacks.EarlyStopping(
                monitor='val_loss',
                patience=10,
                restore_best_weights=True
            ),
            keras.callbacks.ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.2,
                patience=5,
                min_lr=1e-7
            ),
            keras.callbacks.ModelCheckpoint(
                'best_model.h5',
                monitor='val_accuracy',
                save_best_only=True,
                save_weights_only=False
            )
        ]
        
        # 训练模型
        history = model.fit(
            train_generator,
            epochs=epochs,
            validation_data=validation_generator,
            callbacks=callbacks,
            verbose=1
        )
        
        return history
    
    def fine_tune_model(self, model, train_generator, validation_generator, epochs=20):
        """微调模型"""
        # 解冻部分层
        if hasattr(model.layers[0], 'trainable'):
            model.layers[0].trainable = True
            
        # 使用更小的学习率
        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=0.0001/10),
            loss='categorical_crossentropy',
            metrics=['accuracy', 'top_2_accuracy']
        )
        
        # 微调训练
        history = model.fit(
            train_generator,
            epochs=epochs,
            validation_data=validation_generator,
            verbose=1
        )
        
        return history

class PyTorchModel(nn.Module):
    """PyTorch模型实现"""
    
    def __init__(self, num_classes=9):
        super(PyTorchModel, self).__init__()
        self.num_classes = num_classes
        
        # 卷积层
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        
        # 池化层
        self.pool = nn.MaxPool2d(2, 2)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))
        
        # 全连接层
        self.fc1 = nn.Linear(256, 512)
        self.bn5 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 256)
        self.bn6 = nn.BatchNorm1d(256)
        self.fc3 = nn.Linear(256, num_classes)
        
        # Dropout
        self.dropout = nn.Dropout(0.5)
    
    def forward(self, x):
        # 卷积块1
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = F.dropout(x, 0.25, training=self.training)
        
        # 卷积块2
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = F.dropout(x, 0.25, training=self.training)
        
        # 卷积块3
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        x = F.dropout(x, 0.25, training=self.training)
        
        # 卷积块4
        x = self.pool(F.relu(self.bn4(self.conv4(x))))
        x = F.dropout(x, 0.25, training=self.training)
        
        # 自适应池化
        x = self.adaptive_pool(x)
        x = x.view(x.size(0), -1)
        
        # 全连接层
        x = F.relu(self.bn5(self.fc1(x)))
        x = self.dropout(x)
        x = F.relu(self.bn6(self.fc2(x)))
        x = self.dropout(x)
        x = self.fc3(x)
        
        return x

class ModelTrainer:
    """模型训练器"""
    
    def __init__(self, framework='tensorflow'):
        self.framework = framework
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
    def train_tensorflow_model(self, data_dir, model_type='transfer'):
        """训练TensorFlow模型"""
        # 数据预处理
        preprocessor = DataPreprocessor()
        
        # 创建数据生成器
        train_generator = preprocessor.create_data_generator(
            os.path.join(data_dir, 'train'), 
            batch_size=32, 
            training=True
        )
        
        val_generator = preprocessor.create_data_generator(
            os.path.join(data_dir, 'validation'), 
            batch_size=32, 
            training=False
        )
        
        # 创建模型
        cnn_model = CNNModel()
        if model_type == 'custom':
            model = cnn_model.build_custom_model()
        else:
            model = cnn_model.build_transfer_learning_model()
        
        # 编译模型
        model = cnn_model.compile_model(model)
        
        # 训练模型
        history = cnn_model.train_model(model, train_generator, val_generator)
        
        # 微调(仅适用于迁移学习)
        if model_type == 'transfer':
            history_fine = cnn_model.fine_tune_model(model, train_generator, val_generator)
        
        # 保存模型
        model.save('soybean_disease_model.h5')
        
        return model, history
    
    def train_pytorch_model(self, data_dir, epochs=50, batch_size=32, learning_rate=0.001):
        """训练PyTorch模型"""
        # 数据变换
        transform_train = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                               std=[0.229, 0.224, 0.225])
        ])
        
        transform_val = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                               std=[0.229, 0.224, 0.225])
        ])
        
        # 数据加载器
        from torchvision.datasets import ImageFolder
        from torch.utils.data import DataLoader
        
        train_dataset = ImageFolder(
            os.path.join(data_dir, 'train'), 
            transform=transform_train
        )
        val_dataset = ImageFolder(
            os.path.join(data_dir, 'validation'), 
            transform=transform_val
        )
        
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        
        # 创建模型
        model = PyTorchModel(num_classes=len(train_dataset.classes))
        model.to(self.device)
        
        # 损失函数和优化器
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
        
        # 训练循环
        train_losses = []
        val_accuracies = []
        
        for epoch in range(epochs):
            # 训练阶段
            model.train()
            running_loss = 0.0
            
            for images, labels in train_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                
                running_loss += loss.item()
            
            # 验证阶段
            model.eval()
            correct = 0
            total = 0
            
            with torch.no_grad():
                for images, labels in val_loader:
                    images, labels = images.to(self.device), labels.to(self.device)
                    outputs = model(images)
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
            
            val_accuracy = 100 * correct / total
            train_loss = running_loss / len(train_loader)
            
            train_losses.append(train_loss)
            val_accuracies.append(val_accuracy)
            
            print(f'Epoch [{epoch+1}/{epochs}], Loss: {train_loss:.4f}, Accuracy: {val_accuracy:.2f}%')
            
            scheduler.step()
        
        # 保存模型
        torch.save(model.state_dict(), 'soybean_disease_pytorch.pth')
        
        return model, train_losses, val_accuracies

class DiseasePredictor:
    """病虫害预测器"""
    
    def __init__(self, model_path=None, framework='tensorflow'):
        self.framework = framework
        self.model = None
        self.class_names = [
            '大豆霜霉病', '大豆锈病', '大豆炭疽病', '大豆花叶病毒病',
            '大豆胞囊线虫病', '蚜虫', '豆天蛾', '豆荚螟', '健康'
        ]
        self.preprocessor = DataPreprocessor()
        
        if model_path and os.path.exists(model_path):
            self.load_model(model_path)
    
    def load_model(self, model_path):
        """加载预训练模型"""
        try:
            if self.framework == 'tensorflow':
                self.model = keras.models.load_model(model_path)
            elif self.framework == 'pytorch':
                self.model = PyTorchModel()
                self.model.load_state_dict(torch.load(model_path, map_location='cpu'))
                self.model.eval()
            print(f"模型加载成功: {model_path}")
        except Exception as e:
            print(f"模型加载失败: {str(e)}")
            self.model = None
    
    def predict_single_image(self, image_path):
        """预测单张图像"""
        if self.model is None:
            return self._mock_prediction(image_path)
        
        try:
            # 预处理图像
            image = self.preprocessor.load_and_preprocess_image(image_path)
            
            if self.framework == 'tensorflow':
                # TensorFlow预测
                image_batch = np.expand_dims(image, axis=0)
                predictions = self.model.predict(image_batch)[0]
                
            elif self.framework == 'pytorch':
                # PyTorch预测
                transform = transforms.Compose([
                    transforms.ToPILImage(),
                    transforms.Resize((224, 224)),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                       std=[0.229, 0.224, 0.225])
                ])
                
                # 反归一化后转换
                image_denorm = (image * self.preprocessor.std + self.preprocessor.mean) * 255
                image_denorm = image_denorm.astype(np.uint8)
                image_tensor = transform(image_denorm).unsqueeze(0)
                
                with torch.no_grad():
                    outputs = self.model(image_tensor)
                    predictions = torch.softmax(outputs, dim=1)[0].numpy()
            
            # 处理预测结果
            return self._process_predictions(predictions, image_path)
            
        except Exception as e:
            print(f"预测出错: {str(e)}")
            return self._mock_prediction(image_path)
    
    def _process_predictions(self, predictions, image_path):
        """处理预测结果"""
        # 获取置信度排序
        sorted_indices = np.argsort(predictions)[::-1]
        
        results = []
        for i in sorted_indices[:5]:  # 取前5个结果
            confidence = predictions[i]
            disease = self.class_names[i]
            severity = self._determine_severity(disease, confidence)
            
            results.append({
                'disease': disease,
                'confidence': float(confidence),
                'severity': severity
            })
        
        # 获取防治建议
        top_disease = results[0]['disease']
        recommendations = self._get_recommendations(top_disease)
        
        return {
            'predictions': results,
            'top_prediction': results[0],
            'recommendations': recommendations,
            'image_path': image_path
        }
    
    def _determine_severity(self, disease, confidence):
        """确定严重程度"""
        if disease == '健康':
            return '无'
        
        if confidence > 0.8:
            return '重度'
        elif confidence > 0.6:
            return '中度'
        else:
            return '轻度'
    
    def _mock_prediction(self, image_path):
        """模拟预测(用于演示)"""
        np.random.seed(hash(image_path) % 2**32)
        predictions = np.random.dirichlet(np.ones(len(self.class_names)), size=1)[0]
        return self._process_predictions(predictions, image_path)
    
    def _get_recommendations(self, disease):
        """获取防治建议"""
        recommendations_dict = {
            "大豆霜霉病": [
                "选用抗病品种如合丰25、垦丰16等",
                "合理密植,改善田间通风透光条件",
                "及时排水,降低田间湿度",
                "发病初期喷施58%甲霜灵·锰锌可湿性粉剂500倍液",
                "收获后深翻土壤,清除病残体"
            ],
            "大豆锈病": [
                "选择抗锈病品种",
                "合理轮作,避免连作",
                "发病初期喷施15%三唑酮可湿性粉剂1000倍液",
                "也可使用25%丙环唑乳油2000倍液",
                "收获后及时清理病残体并深翻土壤"
            ],
            "大豆炭疽病": [
                "选用无病种子或进行种子消毒",
                "实行2-3年轮作制",
                "增施钾肥,提高植株抗病性",
                "发病初期喷施70%甲基硫菌灵可湿性粉剂1000倍液",
                "收获后清洁田园,减少病原菌来源"
            ],
            "大豆花叶病毒病": [
                "选用抗病毒病品种",
                "防治蚜虫等传毒媒介昆虫",
                "及时拔除病株并带出田外销毁",
                "使用10%吡虫啉可湿性粉剂3000倍液防治蚜虫",
                "播种前进行种子处理"
            ],
            "大豆胞囊线虫病": [
                "选用抗线虫品种",
                "实行轮作,种植玉米、小麦等非寄主作物",
                "土壤深翻晒垡,降低土中线虫密度",
                "使用阿维菌素颗粒剂进行土壤处理",
                "改良土壤,增施有机肥"
            ],
            "蚜虫": [
                "及早发现,及时防治",
                "使用10%吡虫啉可湿性粉剂3000倍液喷施",
                "也可用2.5%溴氰菊酯乳油2000倍液",
                "保护天敌昆虫如瓢虫、草蛉等",
                "使用黄色粘虫板进行物理防治"
            ],
            "豆天蛾": [
                "人工捕杀幼虫",
                "使用黑光灯诱杀成虫",
                "喷施苏云金杆菌(BT)进行生物防治",
                "也可使用4.5%高效氯氰菊酯乳油1500倍液",
                "清除田边杂草,减少越冬场所"
            ],
            "豆荚螟": [
                "秋翻春耙,破坏越冬蛹室",
                "在成虫羽化盛期使用性诱剂诱杀",
                "在卵孵化盛期喷施25%杀虫双水剂500倍液",
                "也可用2.5%溴氰菊酯乳油2000倍液",
                "收获后及时清理豆荚残体"
            ],
            "健康": [
                "继续保持良好的栽培管理措施",
                "定期进行田间巡查,监测病虫害发生情况",
                "适时适量施肥,增强植株抗病能力",
                "保持田间清洁,及时清理杂草和病残体",
                "预防性喷施保护性杀菌剂"
            ]
        }
        
        return recommendations_dict.get(disease, ["请咨询当地农技专家获取具体防治建议"])

class ModelEvaluator:
    """模型评估器"""
    
    def __init__(self, model, class_names):
        self.model = model
        self.class_names = class_names
    
    def evaluate_model(self, test_generator):
        """评估模型性能"""
        # 获取预测结果
        predictions = self.model.predict(test_generator)
        predicted_classes = np.argmax(predictions, axis=1)
        
        # 获取真实标签
        true_classes = test_generator.classes
        
        # 计算混淆矩阵
        cm = confusion_matrix(true_classes, predicted_classes)
        
        # 生成分类报告
        report = classification_report(
            true_classes, 
            predicted_classes, 
            target_names=self.class_names,
            output_dict=True
        )
        
        return {
            'confusion_matrix': cm,
            'classification_report': report,
            'accuracy': report['accuracy']
        }
    
    def plot_confusion_matrix(self, cm):
        """绘制混淆矩阵"""
        plt.figure(figsize=(10, 8))
        plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
        plt.title('Confusion Matrix')
        plt.colorbar()
        tick_marks = np.arange(len(self.class_names))
        plt.xticks(tick_marks, self.class_names, rotation=45)
        plt.yticks(tick_marks, self.class_names)
        
        thresh = cm.max() / 2.
        for i, j in np.ndindex(cm.shape):
            plt.text(j, i, format(cm[i, j], 'd'),
                    horizontalalignment="center",
                    color="white" if cm[i, j] > thresh else "black")
        
        plt.tight_layout()
        plt.ylabel('True label')
        plt.xlabel('Predicted label')
        plt.show()

# 使用示例
def main():
    # 初始化预测器
    predictor = DiseasePredictor(framework='tensorflow')
    
    # 如果有预训练模型,加载它
    # predictor.load_model('soybean_disease_model.h5')
    
    # 预测示例
    image_path = "sample_soybean_leaf.jpg"
    result = predictor.predict_single_image(image_path)
    
    print("检测结果:")
    for i, pred in enumerate(result['predictions']):
        print(f"{i+1}. {pred['disease']}: {pred['confidence']:.2%} ({pred['severity']})")
    
    print(f"\n防治建议:")
    for rec in result['recommendations']:
        print(f"• {rec}")

if __name__ == "__main__":
    main()

这个完整的算法实现包含了以下核心组件:

1. 数据预处理 (DataPreprocessor)

  • 图像标准化和归一化
  • 数据增强(翻转、旋转、亮度调整等)
  • 批量数据生成器

2. 深度学习模型

  • CNN模型: 自定义卷积神经网络
  • 迁移学习: 基于ResNet50/EfficientNet的预训练模型
  • PyTorch实现: 对应的PyTorch版本模型

3. 模型训练 (ModelTrainer)

  • TensorFlow/Keras训练流程
  • PyTorch训练流程
  • 早停、学习率调度等训练策略

4. 病害预测 (DiseasePredictor)

  • 单图像预测接口
  • 结果后处理
  • 置信度分析和严重程度判断

5. 模型评估 (ModelEvaluator)

  • 性能指标计算
  • 混淆矩阵可视化
  • 分类报告生成
Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐