【1ビットの時代 – MNIST編】BitNet b1.58とBitNet(1-bit)を比較

 

【動画で解説】MNIST編 – BitNet比較

 

 

【1ビットの時代 – MNIST編】
BitNet b1.58(1.58-bit)・ BitNet(1-bit)・Denseの学習結果を比較してみよう


視聴時間:8分3秒

文字情報だけではわかりにくい場合に、BitNetの解説動画をご活用いただけますと幸いです。

 




 

【動画の内容】

0:00 はじめに
0:41 Google Colaboratoryの使い方
1:24 事前準備
2:53 MNISTで学習の比較
6:36 おわりに

 

 

MNIST編:BitNetで実験

 

 

BitLinearのビット数の違いなどを比較できるコード

pocokhc/BitNet(The MIT license)| GitHub
注: BitNet/BitNet b1.58のコードは、マイクロソフト公式では公開されていません。論文の情報をもとに、非公式でオープンソースの開発が行われている段階です。

を公開してくださっている日本人の方がいましたので、プログラムを試したところ面白かったので、すぐに実験できるようにGoogle ColaboratoryのCPUで試せるようにノートブックを作成しました。
ビット数の違いよる変化などについて検証したい日本人のAI初学者の方が、BitNet関連の機械学習プログラムに触れるきっかけになることがありましたら幸いです。
CPUでも気軽に試せるところがいいですね。

 

 

チュートリアルコードリンク・プログラムのライセンス

 

 

Google Colaboratoryのチュートリアルコード:
BitNet-Comparison-MNIST.ipynb(The MIT License)| Google Colaboratory

 

チュートリアルコード「BitNet-Comparison-MNIST.ipynb」のライセンス:

The MIT License

Copyright 2024 child programmer

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

 

 

【MNIST編】
BitNet比較をしてみよう – 最終更新:2024年3月16日

 

 

BitNet/BitNet b1.58などで手書き数字画像の画像認識精度の比較をしてみましょう!

 

 

【事前準備①】Pythonのバージョンを指定

 

 

Python3.12系をインストールし、このノートブックで使えるように有効化します。

【開発者の指定】

・Python 12系(3.12.2)
・PyTorch 2.2.1

実行コード

print('(このコード実行前のPythonのバージョンを確認)')
!python --version
print('\n\n')

# Python3.12系をインストール
!sudo apt install python3.12

# Python3.12系を有効化
!sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 1

# pipをインストール
!sudo apt install python3-pip

# 現在のPythonのバージョンを確認
print('\n\n(現在利用中のPythonのバージョンを確認)')
!python --version

# PyTorchのバージョンの確認
print('\n\n(実行時点のバージョンです)') #2024年3月15日実行時点:PyTprch2.2.1+cu121
import torch
print('PyTprch' + torch.__version__)

 

出力結果

(このコード実行前のPythonのバージョンを確認)
Python 3.10.12

〜

(現在利用中のPythonのバージョンを確認)
Python 3.12.2


(実行時点のバージョンです)
PyTprch2.2.1+cu121

 

 

【事前準備②】BitNetのクローン

 

 

BitNetでMNISTの学習結果の比較ができる機械学習のサンプルコードを公開してくださっている以下のGitHubリボジトリ

pocokhc/BitNet(MIT license)| GitHub

をクローンします。

実行コード

!git clone https://github.com/pocokhc/BitNet
%cd BitNet

 

出力結果

Cloning into 'BitNet'...
remote: Enumerating objects: 18, done.
remote: Counting objects: 100% (18/18), done.
remote: Compressing objects: 100% (14/14), done.
remote: Total 18 (delta 3), reused 15 (delta 3), pack-reused 0
Receiving objects: 100% (18/18), 8.87 KiB | 8.87 MiB/s, done.
Resolving deltas: 100% (3/3), done.
/content/BitNet

 

 

MNISTで学習の比較

 

 

様々なニューラルネットワーク

①「BitLinearなし」
②「BitLinear 1ビット」(BitNet)
③「BitLinear 1.58ビット」(BitNet b1.58)

でMNISTを学習させた結果を比較してみましょう。

学習回数」を指定後に「実行コード」を実行します。
(初期設定の20回で違いがわかるのではないかと思います)
学習は、①②③の順番で実行され、最後に学習結果のグラフが表示されます。

実行コード


import itertools
import os
import sys
import time

import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from torchvision import datasets, transforms

#sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../bitnet"))

from bitnet.bitnet_torch import BitLinear

# use CPU
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


def _train_mnist(
    model: torch.nn.Module,
    epochs: int,
    batch_size: int,
    lr: float,
):
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST("data", train=True, download=True, transform=transform), batch_size=batch_size, shuffle=True
    )
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST("data", train=False, download=True, transform=transform), batch_size=batch_size, shuffle=False
    )

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    history = []
    times = []

    print("train start")
    model.train()
    for epoch in range(epochs):
        t0 = time.time()
        running_loss = 0.0
        for images, labels in train_loader:
            optimizer.zero_grad()
            y = model(images)
            loss = criterion(y, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        times.append(time.time() - t0)
        accuracy = _valid_model(model, test_loader)
        history.append(accuracy)
        print(f"Epoch {epoch + 1}, Loss: {running_loss / len(train_loader):.5f}, Acc: {accuracy:.3%}")
    total_times = list(itertools.accumulate(times))

    accuracy = _valid_model(model, test_loader)
    return history, total_times, accuracy


def _valid_model(model, test_loader):
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total


def simple_test():
    model = nn.Sequential(
        nn.Flatten(),
        BitLinear(28 * 28, 128),
        nn.ReLU(),
        nn.Linear(128, 10),
    )
    print(model)
    _, _, score = _train_mnist(model, 5, 512, 0.001)
    print(f"Accuracy on test data: {score:.3%}")  # Accuracy on test data: 96.00%


def compare(units: int, layers: int, lr: float, epochs: int):
    models = []

    # --- dense
    m = nn.Sequential()
    m.add_module("flatten", nn.Flatten())
    m.add_module("in_layer", nn.Linear(28 * 28, units))
    for _ in range(layers):
        m.add_module("norm", nn.LayerNorm(units))
        m.add_module("linear", nn.Linear(units, units, bias=False))
        m.add_module("relu", nn.ReLU())
    m.add_module("linear", nn.Linear(units, 10))
    models.append(["Dense", m])

    # --- 1bit
    m = nn.Sequential()
    m.add_module("flatten", nn.Flatten())
    m.add_module("in_layer", nn.Linear(28 * 28, units))
    for _ in range(layers):
        m.add_module("bitnet", BitLinear(units, units, "1bit", flg_before_linear=False))
        m.add_module("relu", nn.ReLU())
    m.add_module("linear", nn.Linear(units, 10))
    models.append(["BitLinear 1bit", m])

    # --- 1.58bit
    m = nn.Sequential()
    m.add_module("flatten", nn.Flatten())
    m.add_module("in_layer", nn.Linear(28 * 28, units))
    for _ in range(layers):
        m.add_module("bitnet", BitLinear(units, units, "1.58bit"))
        m.add_module("relu", nn.ReLU())
    m.add_module("linear", nn.Linear(units, 10))
    models.append(["BitLinear 1.58bit", m])

    for name, m in models:
        history, times, _ = _train_mnist(m, epochs, 512, lr)
        plt.plot(times, history, label=name)

    plt.ylim(0, 1)
    plt.xlabel("Time")
    plt.ylabel("Accuracy")
    plt.legend()
    plt.grid()
    plt.show()


if __name__ == "__main__":
    # simple_test()
    # @markdown 学習する回数(エポック数)を半角英数で指定します。
    epochs_number = 20 # @param {type:"integer"}
    compare(units=64, layers=5, lr=0.0001, epochs=epochs_number) # 学習回数を変更したい場合には「epochs」の数値を変更します

 

出力結果

Downloading〜

〜

train start(①「BitLinearなし」)
Epoch 1, Loss: 1.59740, Acc: 83.160%
Epoch 2, Loss: 0.70334, Acc: 88.260%

〜

Epoch 19, Loss: 0.27147, Acc: 92.110%
Epoch 20, Loss: 0.26923, Acc: 92.110%

〜

train start(②「BitLinear 1ビット」- BitNet)
Epoch 1, Loss: 1.94890, Acc: 74.660%
Epoch 2, Loss: 1.26848, Acc: 82.680%

〜

Epoch 19, Loss: 0.24331, Acc: 93.120%
Epoch 20, Loss: 0.23618, Acc: 93.110%

train start(③「BitLinear 1.58ビット」- BitNet b1.58)
Epoch 1, Loss: 1.82873, Acc: 79.900%
Epoch 2, Loss: 1.22994, Acc: 85.660%


〜

Epoch 19, Loss: 0.21964, Acc: 93.520%
Epoch 20, Loss: 0.21529, Acc: 93.370%

BitNet比較 - MNISTの学習精度の比較グラフ by 子供プログラマー
* 画像をクリックすると拡大されます。

 

 

おわりに

 

 

2024年2月に公開されて話題となった論文によると、LLMの場合には、パラメータ数が大きくなるほど、BitNet b1.58の良さが出たとのことですが、今回のようにMNISTを使った、ちょっとした実験でもBitNet b1.58の良さが出るようでしたので興味深いですね。

情報を調べる中で感じることとしては、日本人の方をはじめとして世界中のプログラマーの方々が、先人のプログラムを参考にしつつ、思い思いにプログラムを開発されているようですので、今後のBitNet関連の技術の進展や応用が楽しみですね。

今後も、少しずつではありますが、BitNetについて学び、何かしらのアウトプットができればと考えています。

 

 

1ビット:BitNet・BitNet b1.58関連論文

 

 

【1ビットTransformers論文】BitNet: Scaling 1-bit Transformers for Large Language Models

 

 

1ビットTransformers – BitNetの論文:
BitNet: Scaling 1-bit Transformers for Large Language Models – Microsoft Research et al(v1:17 Oct 2023)| arxiv
(BitNet: 大規模な言語モデル向けの 1ビットTransformersのスケーリング 2023年10月17日公開 – マイクロソフトリサーチ 他)

 

 

【1ビットLLM論文】The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits

 

 

1ビットLLM – BitNet b1.58の論文:
The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits – Microsoft Research et al(v1:27 Feb 2024)| arxiv
(1ビットLLMの時代:すべての大規模言語モデルは1.58ビットです 2024年2月27日公開 – マイクロソフトリサーチ 他)

 

1ビットLLM – BitNet b1.58の追加論文:
microsoft/unilm:bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf – Mar 20, 2024 | GitHub
(1ビットLLMの時代: トレーニングのヒント、コード、FAQ – 2024年3月20日公開)

 

 

by 子供プログラマー

 

 

MNISTで作成した学習モデルを使って、画像認識に挑戦
【1ビットの時代 – MNIST編】BitNet b1.58の基盤技術BitLinear/BitNetで画像認識

日本語LLMをファインチューニングし自分だけのカスタム対話型AIに挑戦
日本語LLMのファインチューニング入門 – 自作・Hugging Face公開データセット対応

気軽にチャットAIが始められるおすすめの拡張機能です。会員登録やログイン不要で使えるチャットAIもあります。
【使い方】ChatHub入門 – チャットAIをはじめよう