1"""
2๋ฌธ์์ด ์๊ณ ๋ฆฌ์ฆ (String Algorithms)
3String Matching and Processing
4
5๋ฌธ์์ด ๊ฒ์ ๋ฐ ์ฒ๋ฆฌ ์๊ณ ๋ฆฌ์ฆ์ ๊ตฌํํฉ๋๋ค.
6"""
7
8from typing import List, Tuple
9
10
11# =============================================================================
12# 1. KMP ์๊ณ ๋ฆฌ์ฆ (Knuth-Morris-Pratt)
13# =============================================================================
14
15def kmp_failure(pattern: str) -> List[int]:
16 """
17 KMP ์คํจ ํจ์ (๋ถ๋ถ ์ผ์น ํ
์ด๋ธ) ๊ณ์ฐ
18 ์๊ฐ๋ณต์ก๋: O(m), m = ํจํด ๊ธธ์ด
19 """
20 m = len(pattern)
21 failure = [0] * m
22 j = 0 # ์ด์ ์ต๋ ์ ๋์ฌ ๊ธธ์ด
23
24 for i in range(1, m):
25 while j > 0 and pattern[i] != pattern[j]:
26 j = failure[j - 1]
27
28 if pattern[i] == pattern[j]:
29 j += 1
30 failure[i] = j
31
32 return failure
33
34
35def kmp_search(text: str, pattern: str) -> List[int]:
36 """
37 KMP ๋ฌธ์์ด ๊ฒ์
38 ์๊ฐ๋ณต์ก๋: O(n + m)
39 ๋ฐํ: ํจํด์ด ๋ฐ๊ฒฌ๋ ์์ ์ธ๋ฑ์ค ๋ฆฌ์คํธ
40 """
41 if not pattern:
42 return []
43
44 n, m = len(text), len(pattern)
45 failure = kmp_failure(pattern)
46 matches = []
47 j = 0 # ํจํด ์ธ๋ฑ์ค
48
49 for i in range(n):
50 while j > 0 and text[i] != pattern[j]:
51 j = failure[j - 1]
52
53 if text[i] == pattern[j]:
54 if j == m - 1:
55 matches.append(i - m + 1)
56 j = failure[j]
57 else:
58 j += 1
59
60 return matches
61
62
63# =============================================================================
64# 2. Rabin-Karp ์๊ณ ๋ฆฌ์ฆ
65# =============================================================================
66
67def rabin_karp_search(text: str, pattern: str, mod: int = 10**9 + 7) -> List[int]:
68 """
69 Rabin-Karp ๋ฌธ์์ด ๊ฒ์ (๋กค๋ง ํด์)
70 ์๊ฐ๋ณต์ก๋: ํ๊ท O(n + m), ์ต์
O(nm)
71 """
72 if not pattern or len(pattern) > len(text):
73 return []
74
75 n, m = len(text), len(pattern)
76 base = 256
77 matches = []
78
79 # ํจํด ํด์ ๊ณ์ฐ
80 pattern_hash = 0
81 text_hash = 0
82 h = pow(base, m - 1, mod)
83
84 for i in range(m):
85 pattern_hash = (pattern_hash * base + ord(pattern[i])) % mod
86 text_hash = (text_hash * base + ord(text[i])) % mod
87
88 for i in range(n - m + 1):
89 if pattern_hash == text_hash:
90 # ํด์ ์ถฉ๋ ํ์ธ
91 if text[i:i + m] == pattern:
92 matches.append(i)
93
94 # ๋ค์ ์๋์ฐ ํด์ ๊ณ์ฐ
95 if i < n - m:
96 text_hash = ((text_hash - ord(text[i]) * h) * base + ord(text[i + m])) % mod
97
98 return matches
99
100
101# =============================================================================
102# 3. Z ์๊ณ ๋ฆฌ์ฆ
103# =============================================================================
104
105def z_function(s: str) -> List[int]:
106 """
107 Z ํจ์ ๊ณ์ฐ
108 z[i] = s์ s[i:]์ ์ต์ฅ ๊ณตํต ์ ๋์ฌ ๊ธธ์ด
109 ์๊ฐ๋ณต์ก๋: O(n)
110 """
111 n = len(s)
112 z = [0] * n
113 z[0] = n
114
115 l, r = 0, 0 # Z-box [l, r)
116
117 for i in range(1, n):
118 if i < r:
119 z[i] = min(r - i, z[i - l])
120
121 while i + z[i] < n and s[z[i]] == s[i + z[i]]:
122 z[i] += 1
123
124 if i + z[i] > r:
125 l, r = i, i + z[i]
126
127 return z
128
129
130def z_search(text: str, pattern: str) -> List[int]:
131 """Z ์๊ณ ๋ฆฌ์ฆ์ ์ด์ฉํ ๋ฌธ์์ด ๊ฒ์"""
132 if not pattern:
133 return []
134
135 combined = pattern + "$" + text
136 z = z_function(combined)
137 m = len(pattern)
138
139 return [i - m - 1 for i in range(m + 1, len(combined)) if z[i] == m]
140
141
142# =============================================================================
143# 4. Manacher ์๊ณ ๋ฆฌ์ฆ (์ต์ฅ ํ๋ฌธ ๋ถ๋ถ๋ฌธ์์ด)
144# =============================================================================
145
146def manacher(s: str) -> Tuple[int, int]:
147 """
148 ์ต์ฅ ํ๋ฌธ ๋ถ๋ถ๋ฌธ์์ด ์ฐพ๊ธฐ
149 ์๊ฐ๋ณต์ก๋: O(n)
150 ๋ฐํ: (์์ ์ธ๋ฑ์ค, ๊ธธ์ด)
151 """
152 if not s:
153 return 0, 0
154
155 # ์ ์ฒ๋ฆฌ: ๋ฌธ์ ์ฌ์ด์ # ์ฝ์
156 t = '#' + '#'.join(s) + '#'
157 n = len(t)
158 p = [0] * n # p[i] = i ์ค์ฌ ํ๋ฌธ ๋ฐ์ง๋ฆ
159
160 c, r = 0, 0 # ํ์ฌ ํ๋ฌธ ์ค์ฌ, ์ค๋ฅธ์ชฝ ๊ฒฝ๊ณ
161
162 for i in range(n):
163 if i < r:
164 p[i] = min(r - i, p[2 * c - i])
165
166 # ํ์ฅ ์๋
167 while i - p[i] - 1 >= 0 and i + p[i] + 1 < n and t[i - p[i] - 1] == t[i + p[i] + 1]:
168 p[i] += 1
169
170 # ๊ฒฝ๊ณ ์
๋ฐ์ดํธ
171 if i + p[i] > r:
172 c, r = i, i + p[i]
173
174 # ์ต์ฅ ํ๋ฌธ ์ฐพ๊ธฐ
175 max_len = max(p)
176 center = p.index(max_len)
177
178 # ์๋ณธ ๋ฌธ์์ด์์์ ์์น
179 start = (center - max_len) // 2
180 length = max_len
181
182 return start, length
183
184
185def longest_palindrome(s: str) -> str:
186 """์ต์ฅ ํ๋ฌธ ๋ถ๋ถ๋ฌธ์์ด ๋ฐํ"""
187 start, length = manacher(s)
188 return s[start:start + length]
189
190
191# =============================================================================
192# 5. ์ ๋ฏธ์ฌ ๋ฐฐ์ด (Suffix Array) - ๊ฐ๋จ ๊ตฌํ
193# =============================================================================
194
195def suffix_array(s: str) -> List[int]:
196 """
197 ์ ๋ฏธ์ฌ ๋ฐฐ์ด ์์ฑ (๊ฐ๋จํ O(n logยฒ n) ๊ตฌํ)
198 ๋ฐํ: ์ฌ์ ์ ์ ๋ ฌ๋ ์ ๋ฏธ์ฌ์ ์์ ์ธ๋ฑ์ค
199 """
200 n = len(s)
201 sa = list(range(n))
202 rank = [ord(c) for c in s]
203 tmp = [0] * n
204
205 k = 1
206 while k < n:
207 # (rank[i], rank[i+k]) ๊ธฐ์ค ์ ๋ ฌ
208 def key(i):
209 return (rank[i], rank[i + k] if i + k < n else -1)
210
211 sa.sort(key=key)
212
213 # ์ rank ๊ณ์ฐ
214 tmp[sa[0]] = 0
215 for i in range(1, n):
216 tmp[sa[i]] = tmp[sa[i - 1]]
217 if key(sa[i]) != key(sa[i - 1]):
218 tmp[sa[i]] += 1
219
220 rank = tmp[:]
221 k *= 2
222
223 return sa
224
225
226def lcp_array(s: str, sa: List[int]) -> List[int]:
227 """
228 LCP ๋ฐฐ์ด (Longest Common Prefix)
229 lcp[i] = s[sa[i]:]์ s[sa[i+1]:]์ ์ต์ฅ ๊ณตํต ์ ๋์ฌ ๊ธธ์ด
230 ์๊ฐ๋ณต์ก๋: O(n)
231 """
232 n = len(s)
233 rank = [0] * n
234 for i, idx in enumerate(sa):
235 rank[idx] = i
236
237 lcp = [0] * (n - 1)
238 h = 0
239
240 for i in range(n):
241 if rank[i] > 0:
242 j = sa[rank[i] - 1]
243 while i + h < n and j + h < n and s[i + h] == s[j + h]:
244 h += 1
245 lcp[rank[i] - 1] = h
246 if h > 0:
247 h -= 1
248
249 return lcp
250
251
252# =============================================================================
253# 6. ํธ๋ผ์ด ๊ธฐ๋ฐ ๋ฌธ์์ด ๊ฒ์
254# =============================================================================
255
256class TrieNode:
257 def __init__(self):
258 self.children = {}
259 self.is_end = False
260 self.output = [] # Aho-Corasick์ฉ
261
262
263class AhoCorasick:
264 """
265 Aho-Corasick ์๊ณ ๋ฆฌ์ฆ (๋ค์ค ํจํด ๊ฒ์)
266 ์ ์ฒ๋ฆฌ: O(ฮฃ|patterns|)
267 ๊ฒ์: O(n + m), m = ๋งค์นญ ์
268 """
269
270 def __init__(self, patterns: List[str]):
271 self.root = TrieNode()
272 self.patterns = patterns
273 self._build_trie()
274 self._build_failure()
275
276 def _build_trie(self):
277 for idx, pattern in enumerate(self.patterns):
278 node = self.root
279 for char in pattern:
280 if char not in node.children:
281 node.children[char] = TrieNode()
282 node = node.children[char]
283 node.is_end = True
284 node.output.append(idx)
285
286 def _build_failure(self):
287 from collections import deque
288
289 queue = deque()
290 self.root.fail = self.root
291
292 for child in self.root.children.values():
293 child.fail = self.root
294 queue.append(child)
295
296 while queue:
297 node = queue.popleft()
298
299 for char, child in node.children.items():
300 fail = node.fail
301 while fail != self.root and char not in fail.children:
302 fail = fail.fail
303
304 child.fail = fail.children.get(char, self.root)
305 if child.fail == child:
306 child.fail = self.root
307
308 child.output += child.fail.output
309 queue.append(child)
310
311 def search(self, text: str) -> List[Tuple[int, int]]:
312 """
313 ํ
์คํธ์์ ๋ชจ๋ ํจํด ๊ฒ์
314 ๋ฐํ: [(์์น, ํจํด ์ธ๋ฑ์ค), ...]
315 """
316 results = []
317 node = self.root
318
319 for i, char in enumerate(text):
320 while node != self.root and char not in node.children:
321 node = node.fail
322
323 node = node.children.get(char, self.root)
324
325 for pattern_idx in node.output:
326 pattern = self.patterns[pattern_idx]
327 results.append((i - len(pattern) + 1, pattern_idx))
328
329 return results
330
331
332# =============================================================================
333# 7. ํธ์ง ๊ฑฐ๋ฆฌ (Edit Distance)
334# =============================================================================
335
336def edit_distance(s1: str, s2: str) -> int:
337 """
338 ๋ ๋ฒค์ํ์ธ ๊ฑฐ๋ฆฌ (ํธ์ง ๊ฑฐ๋ฆฌ)
339 ์๊ฐ๋ณต์ก๋: O(mn)
340 ๊ณต๊ฐ๋ณต์ก๋: O(min(m, n)) ์ต์ ํ ๊ฐ๋ฅ
341 """
342 m, n = len(s1), len(s2)
343
344 # ๊ณต๊ฐ ์ต์ ํ: ๋ ํ๋ง ์ฌ์ฉ
345 prev = list(range(n + 1))
346 curr = [0] * (n + 1)
347
348 for i in range(1, m + 1):
349 curr[0] = i
350
351 for j in range(1, n + 1):
352 if s1[i - 1] == s2[j - 1]:
353 curr[j] = prev[j - 1]
354 else:
355 curr[j] = 1 + min(prev[j - 1], prev[j], curr[j - 1])
356
357 prev, curr = curr, prev
358
359 return prev[n]
360
361
362# =============================================================================
363# 8. ๋ฌธ์์ด ํด์ฑ
364# =============================================================================
365
366class StringHash:
367 """
368 ๋คํญ์ ๋กค๋ง ํด์
369 ์ถฉ๋์ ์ค์ด๊ธฐ ์ํด ๋ ๊ฐ์ ํด์ ์ฌ์ฉ (double hashing)
370 """
371
372 def __init__(self, s: str):
373 self.s = s
374 self.n = len(s)
375 self.MOD1 = 10**9 + 7
376 self.MOD2 = 10**9 + 9
377 self.BASE1 = 31
378 self.BASE2 = 37
379
380 self.hash1 = [0] * (self.n + 1)
381 self.hash2 = [0] * (self.n + 1)
382 self.pow1 = [1] * (self.n + 1)
383 self.pow2 = [1] * (self.n + 1)
384
385 for i in range(self.n):
386 self.hash1[i + 1] = (self.hash1[i] * self.BASE1 + ord(s[i])) % self.MOD1
387 self.hash2[i + 1] = (self.hash2[i] * self.BASE2 + ord(s[i])) % self.MOD2
388 self.pow1[i + 1] = self.pow1[i] * self.BASE1 % self.MOD1
389 self.pow2[i + 1] = self.pow2[i] * self.BASE2 % self.MOD2
390
391 def get_hash(self, l: int, r: int) -> Tuple[int, int]:
392 """s[l:r]์ ํด์ ๊ฐ (0-indexed, ๋ฐ์ด๋ฆฐ ๊ตฌ๊ฐ)"""
393 h1 = (self.hash1[r] - self.hash1[l] * self.pow1[r - l]) % self.MOD1
394 h2 = (self.hash2[r] - self.hash2[l] * self.pow2[r - l]) % self.MOD2
395 return (h1, h2)
396
397 def is_equal(self, l1: int, r1: int, l2: int, r2: int) -> bool:
398 """๋ ๋ถ๋ถ๋ฌธ์์ด์ด ๊ฐ์์ง ํ์ธ"""
399 if r1 - l1 != r2 - l2:
400 return False
401 return self.get_hash(l1, r1) == self.get_hash(l2, r2)
402
403
404# =============================================================================
405# ํ
์คํธ
406# =============================================================================
407
408def main():
409 print("=" * 60)
410 print("๋ฌธ์์ด ์๊ณ ๋ฆฌ์ฆ (String Algorithms) ์์ ")
411 print("=" * 60)
412
413 # 1. KMP
414 print("\n[1] KMP ์๊ณ ๋ฆฌ์ฆ")
415 text = "ABABDABACDABABCABAB"
416 pattern = "ABABCABAB"
417 failure = kmp_failure(pattern)
418 matches = kmp_search(text, pattern)
419 print(f" ํ
์คํธ: {text}")
420 print(f" ํจํด: {pattern}")
421 print(f" ์คํจ ํจ์: {failure}")
422 print(f" ๋งค์นญ ์์น: {matches}")
423
424 # 2. Rabin-Karp
425 print("\n[2] Rabin-Karp ์๊ณ ๋ฆฌ์ฆ")
426 matches = rabin_karp_search(text, pattern)
427 print(f" ๋งค์นญ ์์น: {matches}")
428
429 # 3. Z ์๊ณ ๋ฆฌ์ฆ
430 print("\n[3] Z ์๊ณ ๋ฆฌ์ฆ")
431 s = "aabxaab"
432 z = z_function(s)
433 print(f" ๋ฌธ์์ด: {s}")
434 print(f" Z ๋ฐฐ์ด: {z}")
435 matches = z_search(text, pattern)
436 print(f" ๊ฒ์ ๊ฒฐ๊ณผ: {matches}")
437
438 # 4. Manacher
439 print("\n[4] Manacher ์๊ณ ๋ฆฌ์ฆ (์ต์ฅ ํ๋ฌธ)")
440 s = "babad"
441 palindrome = longest_palindrome(s)
442 print(f" ๋ฌธ์์ด: {s}")
443 print(f" ์ต์ฅ ํ๋ฌธ: {palindrome}")
444
445 s2 = "abacdfgdcaba"
446 palindrome2 = longest_palindrome(s2)
447 print(f" ๋ฌธ์์ด: {s2}")
448 print(f" ์ต์ฅ ํ๋ฌธ: {palindrome2}")
449
450 # 5. ์ ๋ฏธ์ฌ ๋ฐฐ์ด
451 print("\n[5] ์ ๋ฏธ์ฌ ๋ฐฐ์ด")
452 s = "banana"
453 sa = suffix_array(s)
454 lcp = lcp_array(s, sa)
455 print(f" ๋ฌธ์์ด: {s}")
456 print(f" ์ ๋ฏธ์ฌ ๋ฐฐ์ด: {sa}")
457 print(" ์ ๋ฏธ์ฌ๋ค:")
458 for i in sa:
459 print(f" {i}: {s[i:]}")
460 print(f" LCP ๋ฐฐ์ด: {lcp}")
461
462 # 6. Aho-Corasick
463 print("\n[6] Aho-Corasick (๋ค์ค ํจํด)")
464 patterns = ["he", "she", "his", "hers"]
465 text = "ahishers"
466 ac = AhoCorasick(patterns)
467 results = ac.search(text)
468 print(f" ํจํด: {patterns}")
469 print(f" ํ
์คํธ: {text}")
470 print(" ๋งค์นญ:")
471 for pos, idx in results:
472 print(f" ์์น {pos}: '{patterns[idx]}'")
473
474 # 7. ํธ์ง ๊ฑฐ๋ฆฌ
475 print("\n[7] ํธ์ง ๊ฑฐ๋ฆฌ")
476 s1, s2 = "kitten", "sitting"
477 dist = edit_distance(s1, s2)
478 print(f" '{s1}' โ '{s2}'")
479 print(f" ํธ์ง ๊ฑฐ๋ฆฌ: {dist}")
480
481 # 8. ๋ฌธ์์ด ํด์ฑ
482 print("\n[8] ๋ฌธ์์ด ํด์ฑ")
483 s = "abcabc"
484 sh = StringHash(s)
485 print(f" ๋ฌธ์์ด: {s}")
486 print(f" hash(0:3) = {sh.get_hash(0, 3)}")
487 print(f" hash(3:6) = {sh.get_hash(3, 6)}")
488 print(f" s[0:3] == s[3:6]: {sh.is_equal(0, 3, 3, 6)}")
489
490 print("\n" + "=" * 60)
491
492
493if __name__ == "__main__":
494 main()