PyTorch 常用語法統整

#PyTorch #Deep Learning #Neural Networks #Machine Learning #Python #Tensors #CUDA #GPU #Dataset #DataLoader #Loss Functions #Optimizers #Training

Table of Contents

安裝 PyTorch

安裝 CPU 版本

若沒有 GPU 或者只需要在 CPU 上執行,可以使用以下命令安裝 PyTorch:

pip install torch torchvision torchaudio

安裝 GPU 版本

若有 NVIDIA GPU,建議使用 CUDA 來加速運算,請根據你的 CUDA 版本選擇對應的安裝指令。 可以前往官方網站 PyTorch 官網 獲取最新的安裝命令。

例如,對於 CUDA 11.8 版本,使用以下命令:

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

對於 CUDA 12.1 版本,使用:

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

如果不確定 CUDA 版本,可以使用以下命令檢查:

nvcc --version

引入 PyTorch

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Dataset
  • torch:PyTorch 的核心庫,負責張量(Tensor)操作、計算圖、GPU 加速等功能。
  • torch.nn:定義神經網路結構,包含全連接層、卷積層、LSTM 層等。
  • torch.optim:提供各種優化器,如隨機梯度下降(SGD)、Adam、RMSprop 等。
  • torch.nn.functional:包含許多不帶權重的函數,如激活函數(ReLU, Sigmoid)、池化(MaxPool2d)等。
  • torchvision.transforms:用於影像數據的增強與標準化,適用於計算機視覺任務。
  • torchvision.datasets:內建多個常見影像數據集,如 MNIST、CIFAR-10,可直接下載使用。
  • torch.utils.data.Dataset:自定義數據集的基類,允許我們定義如何讀取數據。
  • torch.utils.data.DataLoader:批量加載數據,支持多線程加載並打亂數據。

檢查 GPU 是否可用

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

這段代碼會檢查 GPU 是否可用,若不可用則回退至 CPU。

建立 Tensor

tensor = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
print(tensor)

Tensor 是 PyTorch 的核心數據結構,類似於 NumPy 的陣列,但可以加速計算。

將 Tensor 移動到 GPU

tensor = tensor.to(device)

使用 .to(device) 方法可以將 Tensor 移動到指定設備(CPU 或 GPU)。

自訂 Dataset

class CustomDataset(Dataset):
    def __init__(self):
        self.data = torch.randn(100, 2)  # 生成 100 筆隨機數據,每筆 2 維
        self.labels = torch.randn(100, 1)  # 生成 100 筆隨機標籤
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

自訂 Dataset 允許我們定義數據的存儲與提取方式。 這是一個簡單的自訂 Dataset,生成 100 筆隨機數據,每筆資料有 2 個特徵和 1 個標籤。

建立 DataLoader

dataset = CustomDataset()
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

DataLoader 用於批量加載數據,batch_size=10 代表每批 10 筆數據,shuffle=True 讓數據順序隨機打亂。

遍歷 DataLoader

for batch in dataloader:
    inputs, targets = batch
    inputs, targets = inputs.to(device), targets.to(device)
    outputs = model(inputs)
    print(outputs)
    break  # 只展示一批數據

透過 DataLoader 批量加載數據,並送入模型進行預測。


建立簡單神經網路

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(2, 4)  # 第一層線性層,輸入 2 個特徵,輸出 4 個特徵
        self.fc2 = nn.Linear(4, 1)  # 第二層線性層,輸入 4 個特徵,輸出 1 個特徵
    
    def forward(self, x):
        x = F.relu(self.fc1(x))  # 使用 ReLU 激活函數
        x = self.fc2(x)  # 輸出層
        return x

這是一個簡單的神經網路,包含兩層全連接層(Linear),第一層輸入 2 維數據並輸出 4 維,然後通過 ReLU 激活函數,最後輸出 1 個值。

建立模型

model = SimpleNN().to(device)
print(model)

將模型移動到 GPU(如果可用),並印出模型架構。

損失函數與優化器

criterion = nn.MSELoss()  # 均方誤差損失函數
optimizer = optim.Adam(model.parameters(), lr=0.01)  # 使用 Adam 優化器

損失函數用來衡量模型預測與實際值之間的誤差,優化器負責調整模型參數來最小化損失。

損失函數(Loss Function)

損失函數用來衡量模型的預測結果與實際值之間的誤差,選擇合適的損失函數對模型的訓練效果至關重要。

回歸問題常用的損失函數

損失函數說明
nn.MSELoss()均方誤差(Mean Squared Error, MSE),計算預測值與真值之間平方誤差的平均值,適用於回歸問題。
nn.L1Loss()絕對誤差(Mean Absolute Error, MAE),計算預測值與真值之間絕對誤差的平均值,比 MSE 更不容易受極端值影響。
nn.SmoothL1Loss()平滑的 L1 損失,結合了 L1 和 L2 損失的優點,在小誤差時類似於 MSE,大誤差時類似於 MAE,常用於機器學習中的魯棒回歸。

分類問題常用的損失函數

損失函數說明
nn.CrossEntropyLoss()交叉熵損失(Cross-Entropy Loss),適用於多類分類問題,內部包含 Softmax 層,因此輸入應為 raw logits(未經 softmax 處理)。
nn.NLLLoss()負對數似然損失(Negative Log Likelihood Loss),通常與 nn.LogSoftmax() 一起使用,適用於多類分類問題。
nn.BCELoss()二元交叉熵(Binary Cross Entropy Loss),適用於二元分類問題,要求輸出值經過 sigmoid 處理。
nn.BCEWithLogitsLoss()BCELoss 的變體,內部包含 sigmoid 操作,因此可以直接使用 raw logits 輸入,提高數值穩定性。

優化器(Optimizer)

優化器負責調整模型參數,以最小化損失函數的值。

常用的優化器

優化器說明
optim.SGD(model.parameters(), lr=0.01, momentum=0.9)隨機梯度下降(Stochastic Gradient Descent, SGD),可選 momentum 參數來加速收斂。
optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))Adam 優化器,結合了 SGDRMSprop 的優點,適用於大多數深度學習問題。
optim.RMSprop(model.parameters(), lr=0.01, alpha=0.99)RMSprop 優化器,適用於遞歸神經網路(RNN)和非平穩數據的學習。
optim.AdamW(model.parameters(), lr=0.001)AdamW 優化器,相較於 Adam,額外加入了權重衰減(Weight Decay),適用於 Transformer 等模型。
optim.Adagrad(model.parameters(), lr=0.01)自適應梯度(Adagrad),適用於稀疏數據的學習。

如何選擇損失函數與優化器?

  • 回歸問題:通常使用 MSELossSmoothL1Loss,優化器可以選擇 AdamSGD
  • 二元分類問題:通常使用 BCELoss(如果輸入未經 sigmoid 則使用 BCEWithLogitsLoss),優化器可以選擇 AdamSGD
  • 多類分類問題:通常使用 CrossEntropyLoss,優化器可以選擇 AdamSGD
  • 深度學習大模型(如 Transformer):優化器可以選擇 AdamW,以減少 L2 正則化帶來的不良影響。

訓練模型

num_epochs = 20  # 訓練 20 個 Epochs
for epoch in range(num_epochs):  # 迴圈 20 次,每次稱為一個 epoch
    running_loss = 0.0  # 用來累加每個 batch 的 loss,計算 epoch 平均 loss
    
    for batch in dataloader:  # 遍歷 DataLoader,每次獲取一批數據
        inputs, targets = batch  # 取得輸入數據 (inputs) 和標籤 (targets)
        inputs, targets = inputs.to(device), targets.to(device)  # 移動到 GPU 或 CPU

        # 前向傳播(Forward Pass)
        outputs = model(inputs)  # 模型輸入數據,產生預測值
        loss = criterion(outputs, targets)  # 計算預測值與真值之間的誤差
        
        # 反向傳播與優化(Backward Pass & Optimization)
        optimizer.zero_grad()  # 清空梯度,避免累積
        loss.backward()  # 反向傳播,計算梯度
        optimizer.step()  # 更新模型參數
        
        running_loss += loss.item()  # 累加 loss 以計算平均損失
    
    # 輸出當前 epoch 的平均 loss
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(dataloader):.4f}")

這段程式碼將模型訓練 20 個 epochs,每個 epoch 會遍歷整個資料集,計算損失並使用反向傳播來更新模型參數。

Disclaimer: All reference materials on this website are sourced from the internet and are intended for learning purposes only. If you believe any content infringes upon your rights, please contact me at csnote.cc@gmail.com, and I will remove the relevant content promptly.


Feedback Welcome: If you notice any errors or areas for improvement in the articles, I warmly welcome your feedback and corrections. Your input will help this blog provide better learning resources. This is an ongoing process of learning and improvement, and your suggestions are valuable to me. You can reach me at csnote.cc@gmail.com.