# 04_training_techniques.py
  1"""
  204. ํ•™์Šต ๊ธฐ๋ฒ• - PyTorch ๋ฒ„์ „
  3
  4๋‹ค์–‘ํ•œ ์ตœ์ ํ™” ๊ธฐ๋ฒ•๊ณผ ์ •๊ทœํ™”๋ฅผ PyTorch๋กœ ๊ตฌํ˜„ํ•ฉ๋‹ˆ๋‹ค.
  5NumPy ๋ฒ„์ „(examples/numpy/04_training_techniques.py)๊ณผ ๋น„๊ตํ•ด ๋ณด์„ธ์š”.
  6"""
  7
  8import torch
  9import torch.nn as nn
 10import torch.nn.functional as F
 11from torch.utils.data import DataLoader, TensorDataset
 12import numpy as np
 13import matplotlib.pyplot as plt
 14
 15print("=" * 60)
 16print("PyTorch ํ•™์Šต ๊ธฐ๋ฒ•")
 17print("=" * 60)
 18
 19
 20# ============================================
 21# 1. ์˜ตํ‹ฐ๋งˆ์ด์ € ๋น„๊ต
 22# ============================================
 23print("\n[1] ์˜ตํ‹ฐ๋งˆ์ด์ € ๋น„๊ต")
 24print("-" * 40)
 25
 26# ๊ฐ„๋‹จํ•œ ๋ชจ๋ธ ์ •์˜
 27class SimpleNet(nn.Module):
 28    def __init__(self):
 29        super().__init__()
 30        self.fc1 = nn.Linear(2, 16)
 31        self.fc2 = nn.Linear(16, 1)
 32
 33    def forward(self, x):
 34        x = F.relu(self.fc1(x))
 35        x = torch.sigmoid(self.fc2(x))
 36        return x
 37
 38# XOR ๋ฐ์ดํ„ฐ
 39X = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=torch.float32)
 40y = torch.tensor([[0], [1], [1], [0]], dtype=torch.float32)
 41
 42def train_with_optimizer(optimizer_class, **kwargs):
 43    """์ฃผ์–ด์ง„ ์˜ตํ‹ฐ๋งˆ์ด์ €๋กœ ํ•™์Šต"""
 44    torch.manual_seed(42)
 45    model = SimpleNet()
 46    optimizer = optimizer_class(model.parameters(), **kwargs)
 47    criterion = nn.BCELoss()
 48
 49    losses = []
 50    for epoch in range(500):
 51        pred = model(X)
 52        loss = criterion(pred, y)
 53        losses.append(loss.item())
 54
 55        optimizer.zero_grad()
 56        loss.backward()
 57        optimizer.step()
 58
 59    return losses
 60
 61# ๋‹ค์–‘ํ•œ ์˜ตํ‹ฐ๋งˆ์ด์ € ํ…Œ์ŠคํŠธ
 62optimizers = {
 63    'SGD (lr=0.5)': (torch.optim.SGD, {'lr': 0.5}),
 64    'SGD+Momentum': (torch.optim.SGD, {'lr': 0.5, 'momentum': 0.9}),
 65    'Adam': (torch.optim.Adam, {'lr': 0.01}),
 66    'RMSprop': (torch.optim.RMSprop, {'lr': 0.01}),
 67}
 68
 69results = {}
 70for name, (opt_class, params) in optimizers.items():
 71    losses = train_with_optimizer(opt_class, **params)
 72    results[name] = losses
 73    print(f"{name}: ์ตœ์ข… ์†์‹ค = {losses[-1]:.6f}")
 74
 75# ์‹œ๊ฐํ™”
 76plt.figure(figsize=(10, 5))
 77for name, losses in results.items():
 78    plt.plot(losses, label=name)
 79plt.xlabel('Epoch')
 80plt.ylabel('Loss')
 81plt.title('Optimizer Comparison')
 82plt.legend()
 83plt.yscale('log')
 84plt.grid(True, alpha=0.3)
 85plt.savefig('optimizer_comparison.png', dpi=100)
 86plt.close()
 87print("๊ทธ๋ž˜ํ”„ ์ €์žฅ: optimizer_comparison.png")
 88
 89
 90# ============================================
 91# 2. ํ•™์Šต๋ฅ  ์Šค์ผ€์ค„๋Ÿฌ
 92# ============================================
 93print("\n[2] ํ•™์Šต๋ฅ  ์Šค์ผ€์ค„๋Ÿฌ")
 94print("-" * 40)
 95
 96# ์Šค์ผ€์ค„๋Ÿฌ ํ…Œ์ŠคํŠธ
 97def test_scheduler(scheduler_class, **kwargs):
 98    torch.manual_seed(42)
 99    model = SimpleNet()
100    optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
101    scheduler = scheduler_class(optimizer, **kwargs)
102
103    lrs = []
104    for epoch in range(100):
105        lrs.append(optimizer.param_groups[0]['lr'])
106        scheduler.step()
107
108    return lrs
109
schedulers = {
    'StepLR': (torch.optim.lr_scheduler.StepLR, {'step_size': 20, 'gamma': 0.5}),
    'ExponentialLR': (torch.optim.lr_scheduler.ExponentialLR, {'gamma': 0.95}),
    'CosineAnnealingLR': (torch.optim.lr_scheduler.CosineAnnealingLR, {'T_max': 50}),
}

plt.figure(figsize=(10, 5))
for name, (sched_class, params) in schedulers.items():
    lrs = test_scheduler(sched_class, **params)
    plt.plot(lrs, label=name)
    print(f"{name}: ์‹œ์ž‘ {lrs[0]:.4f} โ†’ ๋ {lrs[-1]:.4f}")

plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.title('Learning Rate Schedulers')
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig('lr_schedulers.png', dpi=100)
plt.close()
print("๊ทธ๋ž˜ํ”„ ์ €์žฅ: lr_schedulers.png")
131
132# ============================================
133# 3. Dropout
134# ============================================
135print("\n[3] Dropout")
136print("-" * 40)
137
138class NetWithDropout(nn.Module):
139    def __init__(self, dropout_p=0.5):
140        super().__init__()
141        self.fc1 = nn.Linear(2, 32)
142        self.dropout = nn.Dropout(p=dropout_p)
143        self.fc2 = nn.Linear(32, 1)
144
145    def forward(self, x):
146        x = F.relu(self.fc1(x))
147        x = self.dropout(x)
148        x = torch.sigmoid(self.fc2(x))
149        return x
150
# Observe the dropout effect directly on the hidden layer
model = NetWithDropout(dropout_p=0.5)
x_test = torch.randn(1, 2)

model.train()
print("ํ›ˆ๋ จ ๋ชจ๋“œ (Dropout ํ™œ์„ฑ):")
for i in range(3):
    # Run the hidden layer by hand so we can count surviving neurons.
    out = model.fc1(x_test)
    out = F.relu(out)
    out = model.dropout(out)
    print(f"  ์‹œ๋„ {i+1}: ํ™œ์„ฑ ๋‰ด๋Ÿฐ = {(out != 0).sum().item()}/32")

model.eval()
print("\nํ‰๊ฐ€ ๋ชจ๋“œ (Dropout ๋น„ํ™œ์„ฑ):")
out = model.fc1(x_test)
out = F.relu(out)
out = model.dropout(out)  # identity in eval mode: everything passes through
print(f"  ํ™œ์„ฑ ๋‰ด๋Ÿฐ = {(out != 0).sum().item()}/32")
170
171# ============================================
172# 4. Batch Normalization
173# ============================================
174print("\n[4] Batch Normalization")
175print("-" * 40)
176
177class NetWithBatchNorm(nn.Module):
178    def __init__(self):
179        super().__init__()
180        self.fc1 = nn.Linear(2, 16)
181        self.bn1 = nn.BatchNorm1d(16)
182        self.fc2 = nn.Linear(16, 1)
183
184    def forward(self, x):
185        x = self.fc1(x)
186        x = self.bn1(x)
187        x = F.relu(x)
188        x = torch.sigmoid(self.fc2(x))
189        return x
190
bn_model = NetWithBatchNorm()
# Note: placeholder-free f-strings below had extraneous f prefixes (F541).
print("BatchNorm1d ํŒŒ๋ผ๋ฏธํ„ฐ:")
print(f"  weight (ฮณ): {bn_model.bn1.weight.shape}")
print(f"  bias (ฮฒ): {bn_model.bn1.bias.shape}")
print(f"  running_mean: {bn_model.bn1.running_mean.shape}")
print(f"  running_var: {bn_model.bn1.running_var.shape}")

# Train vs. eval mode: batch statistics vs. running statistics
x_batch = torch.randn(32, 2)

bn_model.train()
out_train = bn_model.fc1(x_batch)
out_train = bn_model.bn1(out_train)
print("\nํ›ˆ๋ จ ๋ชจ๋“œ - ์ถœ๋ ฅ ํ†ต๊ณ„:")
print(f"  mean: {out_train.mean(dim=0)[:3].tolist()}")
print(f"  std: {out_train.std(dim=0)[:3].tolist()}")

bn_model.eval()
out_eval = bn_model.fc1(x_batch)
out_eval = bn_model.bn1(out_eval)
print("ํ‰๊ฐ€ ๋ชจ๋“œ - ์ถœ๋ ฅ ํ†ต๊ณ„:")
print(f"  mean: {out_eval.mean(dim=0)[:3].tolist()}")
214
215# ============================================
216# 5. Weight Decay (L2 ์ •๊ทœํ™”)
217# ============================================
218print("\n[5] Weight Decay")
219print("-" * 40)
220
221def train_with_weight_decay(weight_decay):
222    torch.manual_seed(42)
223    model = SimpleNet()
224    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=weight_decay)
225    criterion = nn.BCELoss()
226
227    for epoch in range(500):
228        pred = model(X)
229        loss = criterion(pred, y)
230        optimizer.zero_grad()
231        loss.backward()
232        optimizer.step()
233
234    # ๊ฐ€์ค‘์น˜ ํฌ๊ธฐ ํ™•์ธ
235    weight_norm = sum(p.norm().item() for p in model.parameters())
236    return loss.item(), weight_norm
237
238for wd in [0, 0.01, 0.1]:
239    loss, w_norm = train_with_weight_decay(wd)
240    print(f"Weight Decay={wd}: ์†์‹ค={loss:.4f}, ๊ฐ€์ค‘์น˜ norm={w_norm:.4f}")
241
242
243# ============================================
244# 6. ์กฐ๊ธฐ ์ข…๋ฃŒ (Early Stopping)
245# ============================================
246print("\n[6] ์กฐ๊ธฐ ์ข…๋ฃŒ")
247print("-" * 40)
248
249class EarlyStopping:
250    def __init__(self, patience=10, min_delta=0):
251        self.patience = patience
252        self.min_delta = min_delta
253        self.counter = 0
254        self.best_loss = None
255        self.early_stop = False
256        self.best_model = None
257
258    def __call__(self, val_loss, model):
259        if self.best_loss is None:
260            self.best_loss = val_loss
261            self.best_model = model.state_dict().copy()
262        elif val_loss > self.best_loss - self.min_delta:
263            self.counter += 1
264            if self.counter >= self.patience:
265                self.early_stop = True
266        else:
267            self.best_loss = val_loss
268            self.best_model = model.state_dict().copy()
269            self.counter = 0
270
# Demo with a simulated validation-loss curve
early_stopping = EarlyStopping(patience=5)
val_losses = [1.0, 0.9, 0.85, 0.8, 0.82, 0.83, 0.84, 0.85, 0.86, 0.87]

model = SimpleNet()
for epoch, val_loss in enumerate(val_losses):
    early_stopping(val_loss, model)
    status = "STOP" if early_stopping.early_stop else f"patience={early_stopping.counter}"
    print(f"Epoch {epoch+1}: val_loss={val_loss:.2f}, {status}")
    if early_stopping.early_stop:
        break
283
284# ============================================
285# 7. ์ „์ฒด ํ•™์Šต ์˜ˆ์ œ
286# ============================================
287print("\n[7] ์ „์ฒด ํ•™์Šต ์˜ˆ์ œ")
288print("-" * 40)
289
290# ๋” ํฐ ๋ฐ์ดํ„ฐ์…‹ ์ƒ์„ฑ
291np.random.seed(42)
292n_samples = 200
293
294# ์›ํ˜• ๋ฐ์ดํ„ฐ (๋น„์„ ํ˜• ๋ฌธ์ œ)
295theta = np.random.uniform(0, 2*np.pi, n_samples)
296r = np.random.uniform(0, 1, n_samples)
297X_train = np.column_stack([r * np.cos(theta), r * np.sin(theta)])
298y_train = (r > 0.5).astype(np.float32)
299
300X_train = torch.tensor(X_train, dtype=torch.float32)
301y_train = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)
302
303# ๊ฒ€์ฆ ๋ฐ์ดํ„ฐ
304X_val = X_train[:40]
305y_val = y_train[:40]
306X_train = X_train[40:]
307y_train = y_train[40:]
308
309# DataLoader
310train_dataset = TensorDataset(X_train, y_train)
311train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
312
class FullModel(nn.Module):
    """A 2-32-16-1 MLP combining batch norm and dropout regularization."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 32)
        self.bn1 = nn.BatchNorm1d(32)
        self.dropout1 = nn.Dropout(0.3)
        self.fc2 = nn.Linear(32, 16)
        self.bn2 = nn.BatchNorm1d(16)
        self.dropout2 = nn.Dropout(0.3)
        self.fc3 = nn.Linear(16, 1)

    def forward(self, x):
        """Map (N, 2) inputs to (N, 1) probabilities."""
        x = F.relu(self.bn1(self.fc1(x)))
        x = self.dropout1(x)
        x = F.relu(self.bn2(self.fc2(x)))
        x = self.dropout2(x)
        x = torch.sigmoid(self.fc3(x))
        return x
332# ๋ชจ๋ธ ์ดˆ๊ธฐํ™”
333torch.manual_seed(42)
334model = FullModel()
335criterion = nn.BCELoss()
336optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-4)
337scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, factor=0.5)
338early_stopping = EarlyStopping(patience=20)
339
340# ํ•™์Šต
341train_losses = []
342val_losses = []
343
344for epoch in range(200):
345    # ํ›ˆ๋ จ
346    model.train()
347    epoch_loss = 0
348    for X_batch, y_batch in train_loader:
349        pred = model(X_batch)
350        loss = criterion(pred, y_batch)
351
352        optimizer.zero_grad()
353        loss.backward()
354        optimizer.step()
355        epoch_loss += loss.item()
356
357    train_losses.append(epoch_loss / len(train_loader))
358
359    # ๊ฒ€์ฆ
360    model.eval()
361    with torch.no_grad():
362        val_pred = model(X_val)
363        val_loss = criterion(val_pred, y_val).item()
364        val_losses.append(val_loss)
365
366    # ์Šค์ผ€์ค„๋Ÿฌ ์—…๋ฐ์ดํŠธ
367    scheduler.step(val_loss)
368
369    # ์กฐ๊ธฐ ์ข…๋ฃŒ ์ฒดํฌ
370    early_stopping(val_loss, model)
371
372    if (epoch + 1) % 40 == 0:
373        lr = optimizer.param_groups[0]['lr']
374        print(f"Epoch {epoch+1}: train={train_losses[-1]:.4f}, val={val_loss:.4f}, lr={lr:.6f}")
375
376    if early_stopping.early_stop:
377        print(f"์กฐ๊ธฐ ์ข…๋ฃŒ at epoch {epoch+1}")
378        break
379
380# ์ตœ๊ณ  ๋ชจ๋ธ ๋ณต์›
381if early_stopping.best_model:
382    model.load_state_dict(early_stopping.best_model)
383
# Visualize the loss curves
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training with Regularization')
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig('full_training.png', dpi=100)
plt.close()
print("๊ทธ๋ž˜ํ”„ ์ €์žฅ: full_training.png")
397
398# ============================================
399# ์ •๋ฆฌ
400# ============================================
401print("\n" + "=" * 60)
402print("ํ•™์Šต ๊ธฐ๋ฒ• ์ •๋ฆฌ")
403print("=" * 60)
404
405summary = """
406๊ถŒ์žฅ ๊ธฐ๋ณธ ์„ค์ •:
407    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
408    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)
409    EarlyStopping(patience=10)
410
411์ •๊ทœํ™” ์กฐํ•ฉ:
412    - Dropout (0.2~0.5): ๊ณผ์ ํ•ฉ ๋ฐฉ์ง€
413    - BatchNorm: ํ•™์Šต ์•ˆ์ •ํ™”
414    - Weight Decay (1e-4~1e-2): ๊ฐ€์ค‘์น˜ ํฌ๊ธฐ ์ œํ•œ
415
416ํ•™์Šต ๋ฃจํ”„ ์ฒดํฌ๋ฆฌ์ŠคํŠธ:
417    1. model.train() / model.eval() ๋ชจ๋“œ ์ „ํ™˜
418    2. optimizer.zero_grad() ํ˜ธ์ถœ
419    3. loss.backward()
420    4. optimizer.step()
421    5. scheduler.step() (์—ํญ ๋)
422    6. EarlyStopping ์ฒดํฌ
423"""
424print(summary)
425print("=" * 60)