def kth_smallest(matrix, k):
    """Return the k-th smallest value (1-based) in a row- and column-sorted matrix.

    The search is a binary search over the *value* range
    ``[matrix[0][0], matrix[-1][-1]]``. For each candidate ``mid`` it counts
    how many entries are ``<= mid`` using a staircase walk from the
    bottom-left corner. That walk is correct only if every row and every
    column is sorted in ascending order. Duplicates are counted individually.

    The matrix may be rectangular (m rows x n columns). All rows are assumed
    to have the same length as ``matrix[0]``.

    Values must be integers: the search steps with ``//`` and ``mid + 1``,
    which does not converge correctly on non-integer values.

    Args:
        matrix: Non-empty list of equal-length rows, each row and each
            column sorted ascending.
        k: 1-based rank of the value to return, ``1 <= k <= m * n``.

    Returns:
        The k-th smallest value in ``matrix``.

    Raises:
        IndexError: If ``matrix`` is empty (no rows or no columns), or if
            ``k`` is outside ``[1, m * n]``.

    Runs in O((m + n) * log(max - min)) time and O(1) extra space.
    """
    if not matrix or not matrix[0]:
        raise IndexError("kth_smallest() requires a non-empty matrix")

    rows, cols = len(matrix), len(matrix[0])
    if not 1 <= k <= rows * cols:
        raise IndexError(
            f"k={k} is out of range for a matrix with {rows * cols} elements"
        )

    # Smallest and largest values sit at opposite corners of a sorted matrix.
    low, high = matrix[0][0], matrix[rows - 1][cols - 1]

    # Invariant: the answer lies in [low, high]. Shrink until they meet.
    while low < high:
        mid = low + (high - low) // 2

        # Staircase count of entries <= mid, starting at the bottom-left.
        # If matrix[r][c] <= mid, the whole column c from row 0..r is <= mid
        # (columns ascend), so add r + 1 and move right. Otherwise this row's
        # remaining cells are all > mid (rows ascend), so move up.
        r, c = rows - 1, 0
        count = 0
        while r >= 0 and c < cols:
            if matrix[r][c] <= mid:
                count += r + 1
                c += 1
            else:
                r -= 1

        if count >= k:
            # mid may itself be the answer; keep it in range.
            high = mid
        else:
            # Fewer than k values are <= mid, so the answer is strictly larger.
            low = mid + 1

    return low