从零搭建深度学习开发环境:Ubuntu+Conda+TensorFlow+GPU+PyCharm

作者:菠萝爱吃肉2024.08.29 04:12浏览量:18

简介:本文指导非专业读者如何在Ubuntu系统上,利用Conda管理Python环境,安装支持GPU的TensorFlow,并通过PyCharm IDE高效进行深度学习开发。从环境准备到项目配置,每一步都详细讲解,助力初学者快速上手。

引言

深度学习领域,一个高效、稳定的开发环境是项目成功的关键。本文将详细介绍如何在Ubuntu操作系统上,通过Conda管理Python环境,安装支持GPU加速的TensorFlow库,并使用PyCharm作为集成开发环境(IDE)来搭建一个完整的深度学习开发环境。无论你是机器学习爱好者还是准备进入AI领域的开发者,本文都将为你提供实用的指导。

第一步:安装Ubuntu系统

假设你已经安装了Ubuntu系统。如果没有,可以从Ubuntu官网下载ISO镜像,并使用U盘启动盘制作工具(如Rufus、Etcher等)制作启动盘,然后按照提示安装Ubuntu。

第二步:安装NVIDIA驱动

由于我们将使用GPU加速,首先需要确保你的Ubuntu系统安装了正确的NVIDIA显卡驱动。可以通过以下步骤安装:

  1. 添加Ubuntu的图形驱动PPA

    1. sudo add-apt-repository ppa:graphics-drivers/ppa
    2. sudo apt update
  2. 安装推荐的NVIDIA驱动
    使用ubuntu-drivers devices查看推荐的驱动版本,然后使用sudo apt install nvidia-xxx(xxx为版本号)安装。

  3. 重启系统

    1. sudo reboot

第三步:安装Conda

Conda是一个开源的包、依赖和环境管理器,非常适合管理Python环境。

  1. 下载并安装Miniconda
    访问Miniconda官网下载对应Ubuntu版本的Miniconda安装脚本。

  2. 安装Miniconda
    打开终端,切换到下载目录,执行安装脚本:

    1. bash Miniconda3-latest-Linux-x86_64.sh

    按照提示完成安装。

  3. 初始化Conda
    在终端执行source ~/.bashrc或重新打开终端,使Conda生效。

第四步:创建并激活Python环境

使用Conda创建一个新的Python环境,并安装TensorFlow(GPU版)。

  1. 创建环境

    1. conda create -n tf-gpu python=3.8
  2. 激活环境

    1. conda activate tf-gpu
  3. 安装TensorFlow-GPU

    1. conda install tensorflow-gpu cudatoolkit=11.0 cudnn=8.0.4 -c conda-forge

    注意:这里安装的CUDA和cuDNN版本需与你的NVIDIA驱动兼容。

第五步:安装PyCharm

PyCharm是一个强大的Python IDE,支持代码调试、版本控制等。

  1. 下载PyCharm
    访问PyCharm官网下载PyCharm的Community或Professional版本。

  2. 解压并运行
    将下载的tar.gz文件解压到合适的位置,然后运行bin/pycharm.sh启动PyCharm。

  3. 配置Python解释器
    在PyCharm中,通过File > Settings > Project: your_project_name > Python Interpreter,选择之前通过Conda创建的tf-gpu环境。

第六步:测试环境

编写一个简单的TensorFlow程序来测试环境是否搭建成功。

  1. import tensorflow as tf
  2. # 检查TensorFlow是否使用了GPU
  3. print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
  4. # 创建一个简单的TensorFlow模型
  5. model = tf.keras.Sequential([tf.keras.layers.Dense(units=1, input_shape=[1])])
  6. model.compile(optimizer='sgd', loss='mean_squared_error')
  7. # 打印模型摘要
  8. model.summary()