1"""
211. ํ์ต ์ต์ ํ
3
4ํ์ดํผํ๋ผ๋ฏธํฐ ํ๋, Mixed Precision, Gradient Accumulation ๋ฑ์ ๊ตฌํํฉ๋๋ค.
5"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import math
import random
15print("=" * 60)
16print("PyTorch ํ์ต ์ต์ ํ")
17print("=" * 60)
18
19device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20print(f"์ฌ์ฉ ์ฅ์น: {device}")
21
22
# ============================================
# 1. Reproducibility
# ============================================
print("\n[1] Reproducibility")
print("-" * 40)

def set_seed(seed=42):
    """Seed all RNGs for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # Deterministic cuDNN kernels trade speed for reproducibility
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_seed(42)
print("Seed set: 42")


# ============================================
# 2. Toy Model and Data
# ============================================
print("\n[2] Toy Model and Data")
print("-" * 40)

class SimpleNet(nn.Module):
    def __init__(self, input_size=784, hidden_size=256, num_classes=10, dropout=0.5):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.bn1 = nn.BatchNorm1d(hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.bn1(self.fc1(x)))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# Dummy data (MNIST-shaped)
X_train = torch.randn(1000, 1, 28, 28)
y_train = torch.randint(0, 10, (1000,))
X_val = torch.randn(200, 1, 28, 28)
y_val = torch.randint(0, 10, (200,))

train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

# ============================================
# 3. Learning Rate Schedulers
# ============================================
print("\n[3] Learning Rate Schedulers")
print("-" * 40)

def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps):
    """Linear warmup followed by cosine decay."""
    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# Test
model = SimpleNet()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps=100, total_steps=1000)

lrs = []
for step in range(1000):
    lrs.append(optimizer.param_groups[0]['lr'])
    scheduler.step()

print(f"Warmup phase (steps 0-100): {lrs[0]:.6f} → {lrs[99]:.6f}")
print(f"Decay phase (steps 100-1000): {lrs[100]:.6f} → {lrs[-1]:.6f}")
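
# The summary at the end also recommends OneCycleLR. A minimal sketch of the
# built-in torch.optim.lr_scheduler.OneCycleLR for comparison; like the
# warmup+cosine scheduler above, it is stepped once per batch. The pct_start
# value here is an illustrative choice, not from the original script.
model_oc = SimpleNet()
optimizer_oc = torch.optim.Adam(model_oc.parameters(), lr=1e-3)
one_cycle = torch.optim.lr_scheduler.OneCycleLR(
    optimizer_oc, max_lr=1e-3, total_steps=1000, pct_start=0.1
)
oc_lrs = []
for _ in range(1000):
    oc_lrs.append(optimizer_oc.param_groups[0]['lr'])
    one_cycle.step()
print(f"OneCycleLR: peak={max(oc_lrs):.6f}, final={oc_lrs[-1]:.6f}")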


# ============================================
# 4. Early Stopping
# ============================================
print("\n[4] Early Stopping")
print("-" * 40)

class EarlyStopping:
    def __init__(self, patience=10, min_delta=0, restore_best=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best = restore_best
        self.counter = 0
        self.best_loss = None
        self.best_weights = None
        self.early_stop = False

    def __call__(self, val_loss, model):
        if self.best_loss is None:
            self.best_loss = val_loss
            self._save_checkpoint(model)
        elif val_loss > self.best_loss - self.min_delta:
            # No sufficient improvement
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
                if self.restore_best and self.best_weights is not None:
                    model.load_state_dict(self.best_weights)
        else:
            self.best_loss = val_loss
            self._save_checkpoint(model)
            self.counter = 0

    def _save_checkpoint(self, model):
        # Keep a CPU copy of the best weights
        self.best_weights = {k: v.cpu().clone() for k, v in model.state_dict().items()}

# Test
early_stopping = EarlyStopping(patience=3)
losses = [1.0, 0.9, 0.8, 0.85, 0.86, 0.87, 0.88]

print("Early stopping simulation:")
for epoch, loss in enumerate(losses):
    early_stopping(loss, model)
    status = "STOP" if early_stopping.early_stop else f"counter={early_stopping.counter}"
    print(f"  Epoch {epoch}: loss={loss:.2f}, {status}")
    if early_stopping.early_stop:
        break


# ============================================
# 5. Gradient Accumulation
# ============================================
print("\n[5] Gradient Accumulation")
print("-" * 40)

def train_with_accumulation(model, train_loader, optimizer, accumulation_steps=4):
    """Train with gradient accumulation."""
    model.train()
    optimizer.zero_grad()
    total_loss = 0

    for i, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        output = model(data)
        loss = F.cross_entropy(output, target)
        loss = loss / accumulation_steps  # scale so accumulated grads average correctly
        loss.backward()

        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        total_loss += loss.item() * accumulation_steps

    # Flush leftover gradients when the batch count is not a multiple
    # of accumulation_steps
    if (i + 1) % accumulation_steps != 0:
        optimizer.step()
        optimizer.zero_grad()

    return total_loss / len(train_loader)

# Test
model = SimpleNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

loss = train_with_accumulation(model, train_loader, optimizer, accumulation_steps=4)
print(f"Accumulation training loss: {loss:.4f}")
print("Effective batch size: 32 × 4 = 128")
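
# Note: with BatchNorm layers (as in SimpleNet), accumulation is not exactly
# equivalent to one large batch: BN statistics are still computed per
# micro-batch of 32; only the gradient averaging matches the 128-sample batch.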


# ============================================
# 6. Mixed Precision Training
# ============================================
print("\n[6] Mixed Precision Training")
print("-" * 40)

if torch.cuda.is_available():
    from torch.cuda.amp import autocast, GradScaler

    def train_with_amp(model, train_loader, optimizer, scaler):
        """Train with automatic mixed precision (AMP)."""
        model.train()
        total_loss = 0

        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()

            # Forward pass runs in float16 where safe
            with autocast():
                output = model(data)
                loss = F.cross_entropy(output, target)

            # Scale the loss to avoid float16 gradient underflow
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()

        return total_loss / len(train_loader)

    model = SimpleNet().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scaler = GradScaler()

    loss = train_with_amp(model, train_loader, optimizer, scaler)
    print(f"AMP training loss: {loss:.4f}")
else:
    print("CUDA not available - skipping AMP")
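
# Recent PyTorch releases deprecate the torch.cuda.amp namespace in favor of
# the unified torch.amp API. A minimal sketch of the equivalent calls,
# assuming PyTorch >= 2.3 (not part of the original script):
#
#     from torch.amp import autocast, GradScaler
#     scaler = GradScaler('cuda')
#     with autocast('cuda'):
#         output = model(data)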


# ============================================
# 7. Gradient Clipping
# ============================================
print("\n[7] Gradient Clipping")
print("-" * 40)

def train_with_clipping(model, train_loader, optimizer, max_norm=1.0):
    """Train with gradient norm clipping."""
    model.train()
    total_loss = 0
    grad_norms = []

    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()

        # Record the gradient norm (before clipping)
        total_norm = 0
        for p in model.parameters():
            if p.grad is not None:
                total_norm += p.grad.data.norm(2).item() ** 2
        grad_norms.append(total_norm ** 0.5)

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(train_loader), grad_norms

model = SimpleNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

loss, norms = train_with_clipping(model, train_loader, optimizer, max_norm=1.0)
print(f"Clipping training loss: {loss:.4f}")
print(f"Mean gradient norm: {np.mean(norms):.4f}")
print(f"Max gradient norm: {np.max(norms):.4f}")
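
# Since clip_grad_norm_ returns the total norm measured before clipping,
# the manual norm loop above can be collapsed into a single call; a minimal
# sketch (the tiny 8-sample batch here is illustrative only):
demo_model = SimpleNet().to(device)
demo_out = demo_model(X_train[:8].to(device))
F.cross_entropy(demo_out, y_train[:8].to(device)).backward()
pre_clip = torch.nn.utils.clip_grad_norm_(demo_model.parameters(), max_norm=1.0)
print(f"Pre-clip gradient norm: {pre_clip.item():.4f}")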


# ============================================
# 8. Hyperparameter Search (Random Search)
# ============================================
print("\n[8] Hyperparameter Search")
print("-" * 40)

def evaluate(model, val_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    return correct / total

def train_with_config(lr, batch_size, dropout, epochs=5):
    """Train a model with the given configuration."""
    set_seed(42)

    model = SimpleNet(dropout=dropout).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    for epoch in range(epochs):
        model.train()
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            loss = F.cross_entropy(model(data), target)
            loss.backward()
            optimizer.step()

    return evaluate(model, val_loader)

# Random Search
print("Running random search...")

# Sample every configuration up front: set_seed() inside train_with_config
# reseeds the global RNG, so sampling inside the trial loop would make every
# trial after the first draw identical hyperparameters.
configs = [
    (10 ** random.uniform(-4, -2),   # learning rate, log-uniform over [1e-4, 1e-2]
     random.choice([32, 64, 128]),   # batch size
     random.uniform(0.2, 0.5))       # dropout
    for _ in range(5)
]

best_acc = 0
best_config = None
results = []

for trial, (lr, batch_size, dropout) in enumerate(configs):
    acc = train_with_config(lr, batch_size, dropout, epochs=3)
    results.append((lr, batch_size, dropout, acc))

    if acc > best_acc:
        best_acc = acc
        best_config = (lr, batch_size, dropout)

    print(f"  Trial {trial+1}: lr={lr:.6f}, bs={batch_size}, dropout={dropout:.2f} → acc={acc:.4f}")

print(f"\nBest config: lr={best_config[0]:.6f}, bs={best_config[1]}, dropout={best_config[2]:.2f}")
print(f"Best accuracy: {best_acc:.4f}")


# ============================================
# 9. Full Training Pipeline
# ============================================
print("\n[9] Full Training Pipeline")
print("-" * 40)

def full_training_pipeline(config):
    """Full training pipeline with all optimization techniques applied."""
    set_seed(config['seed'])

    # Model
    model = SimpleNet(dropout=config['dropout']).to(device)

    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['lr'],
        weight_decay=config['weight_decay']
    )

    # Data loaders
    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config['batch_size'])

    # Scheduler (stepped once per batch)
    total_steps = len(train_loader) * config['epochs']
    warmup_steps = int(total_steps * config['warmup_ratio'])
    scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)

    # Early stopping
    early_stopping = EarlyStopping(patience=config['patience'])

    # AMP (CUDA only)
    use_amp = torch.cuda.is_available()
    scaler = torch.cuda.amp.GradScaler() if use_amp else None

    # Training
    history = {'train_loss': [], 'val_loss': [], 'val_acc': [], 'lr': []}

    for epoch in range(config['epochs']):
        # Train
        model.train()
        train_loss = 0
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()

            if use_amp:
                with torch.cuda.amp.autocast():
                    output = model(data)
                    loss = F.cross_entropy(output, target)
                scaler.scale(loss).backward()
                # Unscale before clipping so the norm is measured in fp32
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
                scaler.step(optimizer)
                scaler.update()
            else:
                output = model(data)
                loss = F.cross_entropy(output, target)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
                optimizer.step()

            scheduler.step()
            train_loss += loss.item()

        train_loss /= len(train_loader)

        # Validate
        model.eval()
        val_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                val_loss += F.cross_entropy(output, target).item()
                pred = output.argmax(dim=1)
                correct += (pred == target).sum().item()
                total += target.size(0)

        val_loss /= len(val_loader)
        val_acc = correct / total

        # Record
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['lr'].append(optimizer.param_groups[0]['lr'])

        # Early stopping check
        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print(f"  Early stopping at epoch {epoch+1}")
            break

        if (epoch + 1) % 5 == 0:
            print(f"  Epoch {epoch+1}: train_loss={train_loss:.4f}, val_acc={val_acc:.4f}")

    return model, history

# Configuration
config = {
    'seed': 42,
    'lr': 1e-3,
    'batch_size': 64,
    'epochs': 20,
    'dropout': 0.3,
    'weight_decay': 0.01,
    'warmup_ratio': 0.1,
    'patience': 5,
    'max_grad_norm': 1.0
}

print("Running full pipeline...")
model, history = full_training_pipeline(config)
print(f"\nFinal validation accuracy: {history['val_acc'][-1]:.4f}")


# ============================================
# Summary
# ============================================
print("\n" + "=" * 60)
print("Training Optimization Summary")
print("=" * 60)

summary = """
Key techniques:

1. Learning rate scheduling
   - Warmup: stabilizes early training
   - Cosine decay: gradual reduction
   - OneCycleLR: adjusts every batch

2. Mixed precision (AMP)
   - Saves memory, speeds up training
   - autocast() + GradScaler()

3. Gradient accumulation
   - Small batches → large-batch effect
   - loss /= accumulation_steps

4. Gradient clipping
   - Prevents exploding gradients
   - clip_grad_norm_(params, max_norm)

5. Early stopping
   - Prevents overfitting
   - Restores the best weights

Recommended setup:
    optimizer = AdamW(lr=1e-4, weight_decay=0.01)
    scheduler = OneCycleLR(max_lr=1e-3)
    scaler = GradScaler()
    early_stopping = EarlyStopping(patience=10)
"""
print(summary)
print("=" * 60)