Neural Netと学習

ML
EDA
Author

Ryo Nakagami

Published

2026-03-17

Modified

2026-03-17

周期関数の学習

以下の区分的定数関数 \(f\) を考える:

\[ f(x) = \begin{cases} -1 & x < -\frac{\pi}{10} \\ 1 & -\frac{\pi}{10} \leq x < \frac{\pi}{10} \\ 0 & x \geq \frac{\pi}{10} \end{cases} \]

ただし,\(f\)\([-1, 1)\) 上で定義された関数を周期的に拡張したものとする(\(x \mapsto ((x+1) \bmod 2) - 1\) で折り返し).

フーリエ級数近似

まず,\(f\) のフーリエ級数展開を確認する.区間 \([-\pi, \pi]\) 上で周期 \(2\pi\) のフーリエ係数を数値積分で求め,\(N = 100\) 項までの部分和で近似する. 不連続点付近ではギブズ現象(Gibbs phenomenon)によるオーバーシュートが観察される.

Code
import numpy as np
import matplotlib.pyplot as plt


def f_periodic(x):
    x = ((x + 1) % 2) - 1  # [-1,1) に折り返し

    if x < -np.pi / 10:
        return -1
    elif x < np.pi / 10:
        return 1
    else:
        return 0


def fourier_coeffs(N=100, M=20000):
    L = np.pi
    xs = np.linspace(-L, L, M)
    dx = xs[1] - xs[0]
    fx = np.array([f_periodic(x) for x in xs])

    a0 = (1 / L) * np.sum(fx) * dx

    an = []
    bn = []

    for n in range(1, N + 1):
        cos_term = np.cos(n * xs)
        sin_term = np.sin(n * xs)

        an.append((1 / L) * np.sum(fx * cos_term) * dx)
        bn.append((1 / L) * np.sum(fx * sin_term) * dx)

    return a0, np.array(an), np.array(bn)


def fourier_approx(x, a0, an, bn):
    result = a0 / 2
    for n in range(1, len(an) + 1):
        result += an[n - 1] * np.cos(n * x) + bn[n - 1] * np.sin(n * x)
    return result


a0, an, bn = fourier_coeffs(N=100)

xs = np.linspace(-np.pi, np.pi, 2000)
fx = np.array([f_periodic(x) for x in xs])
approx = np.array([fourier_approx(x, a0, an, bn) for x in xs])

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(xs, fx, label="f(x)", linewidth=2)
ax.plot(xs, approx, label="Fourier (N=100)", linewidth=1, alpha=0.8)
ax.set_xlim(-np.pi, np.pi)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.legend()
ax.set_title("Fourier Series Approximation")
plt.tight_layout()
plt.show()

Neural Netによる関数近似

フーリエ級数は直交基底の線形結合であり,不連続関数に対しては収束が遅い. ここでは,同じ関数 \(f\) を小規模なニューラルネットワークで近似することを試みる.

入力特徴量の設計

生の \(x\) を直接入力とするのではなく,フーリエ特徴量 \(\{\sin(k\pi x), \cos(k\pi x)\}_{k=1}^{K}\) を入力として与える. これにより,ネットワーク自身が三角関数の非線形変換を学習する必要がなくなり, 少ないパラメータでも不連続点付近の急峻な変化を捉えやすくなる.

ネットワーク構成

サイズ
入力 \(2K\)\(K=5\) のとき 10次元)
隠れ層1 \(4K\) + BatchNorm + ReLU
隠れ層2 \(K\) + BatchNorm + ReLU
隠れ層3 3 + BatchNorm + ReLU
出力 1

学習の工夫

  • BatchNorm: 各層の出力を正規化し,勾配の流れを安定化
  • He初期化: ReLU活性化に適した重み初期化で収束を早める
  • AdamW (weight_decay=\(10^{-4}\)): L2正則化付きでフラットな最小値へ誘導
  • OneCycleLR: warmup → 高学習率 → annealing を1サイクルで行い,少ないepochで効率的に収束させる
  • ミニバッチ (batch_size=512): SGDノイズが暗黙の正則化として働く
  • 勾配クリッピング (max_norm=1.0): 不連続関数の学習で生じる勾配の不安定性を抑制
Code
import torch
import torch.nn as nn
import torch.optim as optim

# --- サンプリング ---
np.random.seed(42)


def f_periodic_vec(x):
    x = ((x + 1) % 2) - 1

    y = np.zeros_like(x)
    y[x < -np.pi / 10] = -1
    y[(x >= -np.pi / 10) & (x < np.pi / 10)] = 1
    y[x >= np.pi / 10] = 0

    return y


K = 5  # フーリエ特徴量の次数


def make_features(x, K=K):
    """sin(kπx), cos(kπx) for k=1..K → 2K次元の特徴量"""
    feats = []
    for k in range(1, K + 1):
        feats.append(np.sin(k * np.pi * x))
        feats.append(np.cos(k * np.pi * x))
    return np.stack(feats, axis=1)


x_train = np.random.uniform(-np.pi / 2, np.pi / 2, 10000)
y_train = f_periodic_vec(x_train)

X = torch.tensor(make_features(x_train), dtype=torch.float32)
Y = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)

model = nn.Sequential(
    nn.Linear(2 * K, 4 * K),
    nn.BatchNorm1d(4 * K),
    nn.ReLU(),
    nn.Linear(4 * K, K),
    nn.BatchNorm1d(K),
    nn.ReLU(),
    nn.Linear(K, 3),
    nn.BatchNorm1d(3),
    nn.ReLU(),
    nn.Linear(3, 1),
)

for m in model:
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
        nn.init.zeros_(m.bias)

criterion = nn.MSELoss()
optimizer = optim.AdamW(model.parameters(), lr=3e-3, weight_decay=1e-4)

# --- ミニバッチ学習 + OneCycleLR ---
dataset = torch.utils.data.TensorDataset(X, Y)
loader = torch.utils.data.DataLoader(dataset, batch_size=512, shuffle=True)

n_epochs = 60
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=1e-2, epochs=n_epochs, steps_per_epoch=len(loader)
)

# --- プロット用の準備 ---
xs_plot = np.linspace(-np.pi, np.pi, 500)
fx_plot = np.array([f_periodic(x) for x in xs_plot])
X_plot_tensor = torch.tensor(make_features(xs_plot), dtype=torch.float32)


# --- 学習ループ(各epochのスナップショットを記録) ---
snapshots = {}
all_losses = []

for epoch in range(n_epochs):
    model.train()
    epoch_loss = 0.0
    for xb, yb in loader:
        pred = model(xb)
        loss = criterion(pred, yb)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        epoch_loss += loss.item() * len(xb)
    epoch_loss /= len(dataset)
    all_losses.append(epoch_loss)

    model.eval()
    with torch.no_grad():
        snap = model(X_plot_tensor).squeeze().cpu().numpy()
    snapshots[epoch] = snap.copy()

    if epoch % 10 == 0:
        print(f"epoch {epoch:3d}  loss={epoch_loss:.6f}")

# --- Plotly (frames + slider) ---
import plotly.graph_objects as go
from plotly.subplots import make_subplots

fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=["Training Loss", "NN Approximation"],
    horizontal_spacing=0.12,
)

# trace 0: loss曲線
fig.add_trace(
    go.Scatter(
        x=list(range(n_epochs)),
        y=all_losses,
        mode="lines",
        name="loss",
        line=dict(color="#1f77b4"),
    ),
    row=1,
    col=1,
)
# trace 1: 現在epochマーカー(赤丸)
fig.add_trace(
    go.Scatter(
        x=[0],
        y=[all_losses[0]],
        mode="markers",
        name="current epoch",
        marker=dict(color="red", size=12),
        showlegend=False,
    ),
    row=1,
    col=1,
)
# trace 2: f(x)
fig.add_trace(
    go.Scatter(
        x=xs_plot,
        y=fx_plot,
        mode="lines",
        name="f(x)",
        line=dict(color="black", width=2),
    ),
    row=1,
    col=2,
)
# trace 3: NN出力(frameで更新)
fig.add_trace(
    go.Scatter(
        x=xs_plot,
        y=snapshots[0],
        mode="lines",
        name="NN",
        line=dict(color="#d62728", width=2),
    ),
    row=1,
    col=2,
)

# --- frames ---
frames = []
for ep in range(n_epochs):
    frames.append(
        go.Frame(
            data=[
                go.Scatter(x=[ep], y=[all_losses[ep]]),
                go.Scatter(x=xs_plot, y=snapshots[ep]),
            ],
            traces=[1, 3],
            name=str(ep),
        )
    )
fig.frames = frames

# --- slider ---
sliders = [
    dict(
        active=0,
        currentvalue=dict(prefix="epoch: "),
        pad=dict(t=30),
        steps=[
            dict(
                args=[
                    [str(ep)],
                    dict(mode="immediate", frame=dict(duration=0, redraw=True)),
                ],
                method="animate",
                label=str(ep),
            )
            for ep in range(n_epochs)
        ],
    )
]

# --- play / pause ---
updatemenus = [
    dict(
        type="buttons",
        showactive=False,
        x=0.0,
        y=-0.15,
        xanchor="left",
        buttons=[
            dict(
                label="&#9654;",
                method="animate",
                args=[
                    None,
                    dict(frame=dict(duration=100, redraw=True), fromcurrent=True),
                ],
            ),
            dict(
                label="&#9724;",
                method="animate",
                args=[
                    [None],
                    dict(frame=dict(duration=0, redraw=True), mode="immediate"),
                ],
            ),
        ],
    )
]

fig.update_xaxes(title_text="epoch", row=1, col=1)
fig.update_yaxes(title_text="MSE loss", type="log", row=1, col=1)
fig.update_xaxes(title_text="x", range=[-np.pi, np.pi], row=1, col=2)
fig.update_yaxes(title_text="y", range=[-1.5, 1.5], row=1, col=2)
fig.update_layout(
    height=500,
    width=1000,
    sliders=sliders,
    updatemenus=updatemenus,
    margin=dict(t=40, b=80),
)
fig.show()
epoch   0  loss=0.791099
epoch  10  loss=0.017091
epoch  20  loss=0.009779
epoch  30  loss=0.007226
epoch  40  loss=0.007049
epoch  50  loss=0.005311
Note
  • 左パネルのloss曲線上でスライダーを動かすと,対応するepochにおけるNNの出力が右パネルに表示される.
  • 初期(epoch 0)ではほぼランダムな出力だが,学習が進むにつれて \(f(x)\) の形状を捉えていく様子が確認できる.