10_tensor_ops_einsum.py

Download
python 299 lines 9.7 KB
  1"""
  2Tensor Operations and Einstein Summation (einsum)
  3
  4This script demonstrates:
  5- Tensor creation and manipulation in NumPy and PyTorch
  6- Einstein summation notation (einsum) for efficient operations
  7- Broadcasting rules and examples
  8- Numerical stability techniques
  9
 10Einstein summation is a concise notation for tensor operations:
 11- Implicit mode: repeated indices are summed
 12- Explicit mode: specify output indices
 13- Used extensively in deep learning (attention, tensor contractions)
 14"""
 15
 16import numpy as np
 17import torch
 18import torch.nn.functional as F
 19
 20
 21def tensor_basics():
 22    """
 23    Basic tensor operations in NumPy and PyTorch.
 24    """
 25    print("=== Tensor Basics ===\n")
 26
 27    # NumPy
 28    np_array = np.array([[1, 2, 3], [4, 5, 6]])
 29    print(f"NumPy array shape: {np_array.shape}")
 30    print(f"NumPy array:\n{np_array}\n")
 31
 32    # PyTorch
 33    torch_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)
 34    print(f"PyTorch tensor shape: {torch_tensor.shape}")
 35    print(f"PyTorch tensor:\n{torch_tensor}\n")
 36
 37    # Conversion
 38    np_from_torch = torch_tensor.numpy()
 39    torch_from_np = torch.from_numpy(np_array)
 40    print(f"Converted NumPy → PyTorch → NumPy: {np.array_equal(np_array, np_from_torch)}\n")
 41
 42
 43def einsum_examples():
 44    """
 45    Comprehensive examples of einsum operations.
 46
 47    Einsum notation:
 48    - 'i,i->': dot product (sum over i)
 49    - 'ij,jk->ik': matrix multiplication
 50    - 'ij->ji': transpose
 51    - 'ii->i': diagonal extraction
 52    - 'ij->': sum all elements
 53    """
 54    print("=== Einstein Summation (einsum) Examples ===\n")
 55
 56    # Example 1: Dot product
 57    a = np.array([1, 2, 3])
 58    b = np.array([4, 5, 6])
 59
 60    dot_manual = np.sum(a * b)
 61    dot_einsum = np.einsum('i,i->', a, b)
 62    print(f"1. Dot product:")
 63    print(f"   Manual: {dot_manual}")
 64    print(f"   Einsum 'i,i->': {dot_einsum}\n")
 65
 66    # Example 2: Matrix multiplication
 67    A = np.random.randn(3, 4)
 68    B = np.random.randn(4, 5)
 69
 70    matmul_manual = np.matmul(A, B)
 71    matmul_einsum = np.einsum('ij,jk->ik', A, B)
 72    print(f"2. Matrix multiplication (3x4 @ 4x5 = 3x5):")
 73    print(f"   Close match: {np.allclose(matmul_manual, matmul_einsum)}\n")
 74
 75    # Example 3: Transpose
 76    C = np.random.randn(3, 4)
 77    transpose_manual = C.T
 78    transpose_einsum = np.einsum('ij->ji', C)
 79    print(f"3. Transpose:")
 80    print(f"   Close match: {np.allclose(transpose_manual, transpose_einsum)}\n")
 81
 82    # Example 4: Trace (sum of diagonal)
 83    D = np.random.randn(4, 4)
 84    trace_manual = np.trace(D)
 85    trace_einsum = np.einsum('ii->', D)
 86    print(f"4. Trace (sum of diagonal):")
 87    print(f"   Manual: {trace_manual:.4f}")
 88    print(f"   Einsum 'ii->': {trace_einsum:.4f}\n")
 89
 90    # Example 5: Batch matrix multiplication
 91    batch_A = np.random.randn(8, 3, 4)  # batch of 8 matrices (3x4)
 92    batch_B = np.random.randn(8, 4, 5)  # batch of 8 matrices (4x5)
 93
 94    batch_matmul_manual = np.matmul(batch_A, batch_B)
 95    batch_matmul_einsum = np.einsum('bij,bjk->bik', batch_A, batch_B)
 96    print(f"5. Batch matrix multiplication (8 x [3x4 @ 4x5]):")
 97    print(f"   Close match: {np.allclose(batch_matmul_manual, batch_matmul_einsum)}\n")
 98
 99    # Example 6: Outer product
100    x = np.array([1, 2, 3])
101    y = np.array([4, 5])
102    outer_manual = np.outer(x, y)
103    outer_einsum = np.einsum('i,j->ij', x, y)
104    print(f"6. Outer product:")
105    print(f"   Close match: {np.allclose(outer_manual, outer_einsum)}\n")
106
107    # Example 7: Hadamard (element-wise) product and sum
108    E = np.random.randn(3, 4)
109    F = np.random.randn(3, 4)
110    hadamard_sum = np.sum(E * F)
111    hadamard_einsum = np.einsum('ij,ij->', E, F)
112    print(f"7. Hadamard product and sum:")
113    print(f"   Manual: {hadamard_sum:.4f}")
114    print(f"   Einsum 'ij,ij->': {hadamard_einsum:.4f}\n")
115
116
def attention_with_einsum():
    """
    Implement scaled dot-product attention using einsum.

    Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V

    Einsum is used here for readability, not speed. The subscript string names
    every axis, so the K transpose and the multi-head split/merge need no
    explicit transpose calls. Each einsum result is cross-checked against an
    equivalent matmul formulation.
    """
    print("=== Attention Mechanism with Einsum ===\n")

    batch_size = 2
    seq_len_q = 10
    seq_len_k = 12
    d_model = 64

    # Query, Key, Value
    Q = torch.randn(batch_size, seq_len_q, d_model)
    K = torch.randn(batch_size, seq_len_k, d_model)
    V = torch.randn(batch_size, seq_len_k, d_model)

    # Method 1: Manual matmul
    scores_manual = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(d_model)
    attn_weights_manual = F.softmax(scores_manual, dim=-1)
    output_manual = torch.matmul(attn_weights_manual, V)

    # Method 2: Using einsum. 'bqd,bkd' contracts the shared feature axis d,
    # which is Q @ K^T without an explicit transpose.
    scores_einsum = torch.einsum('bqd,bkd->bqk', Q, K) / np.sqrt(d_model)
    attn_weights_einsum = F.softmax(scores_einsum, dim=-1)
    output_einsum = torch.einsum('bqk,bkd->bqd', attn_weights_einsum, V)

    print(f"Attention output shape: {output_manual.shape}")
    print(f"Manual vs Einsum match: {torch.allclose(output_manual, output_einsum, atol=1e-6)}\n")

    # Multi-head attention: einsum can work on the (batch, seq, heads, d_k)
    # layout produced by a plain reshape, so no transposes are needed.
    num_heads = 8
    d_k = d_model // num_heads

    # Split the feature axis into heads: (batch, seq, d_model) -> (batch, seq, num_heads, d_k)
    Q_split = Q.reshape(batch_size, seq_len_q, num_heads, d_k)
    K_split = K.reshape(batch_size, seq_len_k, num_heads, d_k)
    V_split = V.reshape(batch_size, seq_len_k, num_heads, d_k)

    # Per-head scores and weights: (batch, num_heads, seq_q, seq_k)
    scores_heads = torch.einsum('bqhd,bkhd->bhqk', Q_split, K_split) / np.sqrt(d_k)
    attn_weights_heads = F.softmax(scores_heads, dim=-1)

    # The output is written directly as (batch, seq_q, num_heads, d_k), so
    # concatenating the heads is just a reshape.
    output_heads = torch.einsum('bhqk,bkhd->bqhd', attn_weights_heads, V_split)
    output_multihead = output_heads.reshape(batch_size, seq_len_q, d_model)

    # Reference: the conventional transpose + matmul formulation.
    Q_t, K_t, V_t = (t.transpose(1, 2) for t in (Q_split, K_split, V_split))
    weights_ref = F.softmax(torch.matmul(Q_t, K_t.transpose(-2, -1)) / np.sqrt(d_k), dim=-1)
    output_ref = torch.matmul(weights_ref, V_t).transpose(1, 2).reshape(batch_size, seq_len_q, d_model)

    print(f"Multi-head attention output shape: {output_multihead.shape}")
    print(f"Attention weights shape (per head): {attn_weights_heads.shape}")
    print(f"Multi-head einsum vs matmul match: {torch.allclose(output_multihead, output_ref, atol=1e-5)}\n")
169
170
def broadcasting_examples():
    """
    Demonstrate broadcasting rules in NumPy and PyTorch.

    Broadcasting rules:
    1. If arrays have different ranks, prepend 1s to the shape of the
       lower-rank array until the ranks match.
    2. Two dimensions are compatible if they are equal or one of them is 1.
    3. The result size along each dimension is the maximum of the input sizes.
    """
    print("=== Broadcasting Examples ===\n")

    # Example 1: Scalar + array
    a = np.array([1, 2, 3])
    b = 5
    result = a + b
    print(f"1. Scalar + array: {a} + {b} = {result}\n")

    # Example 2: Row vector + column vector (outer sum)
    row = np.array([[1, 2, 3]])  # (1, 3)
    col = np.array([[10], [20], [30]])  # (3, 1)
    result = row + col  # (3, 3)
    print("2. Row + column (outer sum):")
    print(f"{result}\n")

    # Example 3: Broadcasting in batch operations (the mean-centering step
    # used by batch normalization)
    batch_data = np.random.randn(32, 10)  # 32 samples, 10 features
    feature_mean = np.mean(batch_data, axis=0, keepdims=True)  # (1, 10)
    centered_data = batch_data - feature_mean  # Broadcast (1, 10) to (32, 10)
    print("3. Batch normalization:")
    print(f"   Data shape: {batch_data.shape}")
    print(f"   Mean shape: {feature_mean.shape}")
    print(f"   Centered data shape: {centered_data.shape}")
    print(f"   Mean after centering: {np.mean(centered_data, axis=0)[:3]} (should be ~0)\n")

    # Example 4: Broadcasting with einsum
    A = np.random.randn(5, 3)
    bias = np.random.randn(3)

    # Implicit broadcasting: bias (3,) is treated as (1, 3) and stretched over 5 rows.
    result_broadcast = A + bias

    # einsum has no addition, but it can materialize the broadcast explicitly.
    # The outer product of a ones-vector with bias tiles it into a (5, 3)
    # matrix, which is exactly what broadcasting does implicitly. (The
    # previous 'i->i' einsum was an identity and demonstrated nothing.)
    tiled_bias = np.einsum('r,i->ri', np.ones(A.shape[0]), bias)
    result_einsum = A + tiled_bias

    print("4. Matrix + vector broadcasting:")
    print(f"   Close match: {np.allclose(result_broadcast, result_einsum)}\n")
217
218
def numerical_stability():
    """
    Demonstrate numerical stability techniques.

    Common issues:
    - Overflow/underflow in exp()
    - Log of small numbers
    - Division by zero

    NumPy does not raise on floating-point overflow by default. It returns
    inf with a RuntimeWarning, and inf / inf then yields nan. The naive
    softmax therefore runs under ``np.errstate(over='raise')`` so that the
    overflow surfaces as a ``FloatingPointError`` rather than as silent nan
    output.
    """
    print("=== Numerical Stability Techniques ===\n")

    # Problem 1: Softmax overflow
    logits = np.array([1000, 1001, 1002])  # Large values

    # Naive softmax: exp(1000) exceeds the float64 range. Without errstate
    # the except branch below would be unreachable.
    try:
        with np.errstate(over='raise', invalid='raise'):
            naive_softmax = np.exp(logits) / np.sum(np.exp(logits))
    except FloatingPointError:
        print("Naive softmax: OVERFLOW!")
    else:
        print(f"Naive softmax: {naive_softmax}")

    # Stable softmax: softmax is invariant to shifting all logits, and after
    # subtracting the max the largest exponent is exp(0) = 1.
    exp_shifted = np.exp(logits - np.max(logits))
    stable_softmax = exp_shifted / np.sum(exp_shifted)
    print(f"Stable softmax: {stable_softmax}")
    print(f"Sum: {np.sum(stable_softmax):.6f}\n")

    # Problem 2: Log-sum-exp trick
    x = np.array([1000, 1001, 1002])

    # Naive log(sum(exp(x))) overflows; factor out the max instead:
    # log(sum(exp(x))) = max(x) + log(sum(exp(x - max(x))))
    max_x = np.max(x)
    log_sum_exp_stable = max_x + np.log(np.sum(np.exp(x - max_x)))
    print(f"Log-sum-exp (stable): {log_sum_exp_stable:.4f}")

    # Verify with scipy (if available): scipy is optional, so a missing install
    # skips the cross-check instead of crashing the demo.
    try:
        from scipy.special import logsumexp
    except ImportError:
        print("Scipy not available; skipping logsumexp check.\n")
    else:
        print(f"Scipy logsumexp: {logsumexp(x):.4f}\n")

    # Problem 3: Numerical precision in small differences
    a = torch.tensor([1e10, 1.0, 1e-10])
    b = torch.tensor([1e10, 1.0, 0.0])

    # torch.allclose tests |a - b| <= atol + rtol * |b| elementwise. The
    # defaults (rtol=1e-5, atol=1e-8) absorb the 1e-10 vs 0.0 gap. With
    # atol=0 the relative term alone cannot absorb it, because |b| is 0 there.
    print(f"Arrays close (default tol): {torch.allclose(a, b)}")
    print(f"Arrays close (rtol=1e-5, atol=0): {torch.allclose(a, b, rtol=1e-5, atol=0.0)}\n")
265
266
if __name__ == "__main__":
    rule = "=" * 60

    print(rule)
    print("Tensor Operations and Einstein Summation")
    print(rule)
    print()

    # Seed both RNG sources so the random demos are reproducible run to run.
    np.random.seed(42)
    torch.manual_seed(42)

    # Every demo except the final one is followed by a blank separator line.
    for demo in (tensor_basics, einsum_examples, attention_with_einsum, broadcasting_examples):
        demo()
        print()
    numerical_stability()

    summary_lines = (
        "Summary:",
        "- Einsum provides concise notation for tensor operations",
        "- Broadcasting enables efficient element-wise operations",
        "- Numerical stability is critical for deep learning",
        "- PyTorch and NumPy have similar tensor APIs",
    )
    print(rule)
    for text in summary_lines:
        print(text)
    print(rule)