PyTorch Lightning入門

  • はじめに PythonのPyTorch Lightningについて学びましょう。PyTorch Lightningは、PyTorchの高レベルAPIで、ディープラーニングの開発をさらに簡単にします。では、二人の説明を見ていきましょう。
Gal Normal

PyTorch Lightningってなに?

Geek Curious

PyTorchの高レベルAPIで、ディープラーニングの開発をもっと簡単にするものだよね?

Gal Happy

そうなの!PyTorch Lightningを使うと、ディープラーニングの モデル の構築や 学習評価 など、よりシンプルにコードを書くことができるの!

Gal Pleased

まずは、PyTorch Lightningを使ってみよう!次のようにインポートして、使えるようにしよう!

import pytorch_lightning as pl
Geek Happy

わかった! import pytorch_lightning as pl でPyTorch Lightningをインポートして使えるようにするんだね!

Gal Happy

そうなの!簡単な例として、ニューラルネットワークのモデルを作ってみよう!

Gal Pleased

例えば、こんな感じで、PyTorch Lightningの LightningModule を継承してモデルを作成できるの!

import torch.nn as nn
import torch.optim as optim

class MyModel(pl.LightningModule):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = nn.Linear(8, 10)
        self.fc2 = nn.Linear(10, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=0.001)
        return optimizer

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.mse_loss(y_hat, y)
        self.log('train_loss', loss)
        return loss
Geek Happy

なるほど、 pl.LightningModule を継承したクラスを作成して、ネットワークの構造を定義するんだね!

Gal Happy

そうなの!PyTorch Lightningを使うと、ディープラーニングのモデル構築や学習が簡単にできるの!

  • おわりに PythonのPyTorch Lightningでは、ディープラーニングがさらに簡単に行えます。これで、ニューラルネットワークのモデル構築や学習、評価などがシンプルに書けるようになりましたね!😄