根据 IP 库中的 IP 范围判断给定 IP 的所在地;IP 库中的 IP 范围是有序的。
IP 库
用户 IP 数据
代码
from pyspark.sql import SparkSession
import sys
def ip_transform(ip):
    """Convert a dotted-quad IPv4 string (e.g. "1.2.3.4") to its 32-bit integer value.

    Each octet is shifted into the accumulator from most-significant to
    least-significant byte, so "1.2.3.4" -> 0x01020304 == 16909060.
    """
    value = 0
    for octet in ip.split("."):
        value = (value << 8) | int(octet)
    return value
def binary_search(ip, square):
    """Binary-search a sorted list of IP ranges for the range containing *ip*.

    Args:
        ip: integer IPv4 value (as produced by ``ip_transform``).
        square: list of tuples ``(start_int, end_int, *region_fields)``,
            sorted ascending and non-overlapping.

    Returns:
        The region fields ``entry[2:]`` of the matching range, or ``None``
        when no range contains *ip*.

    Bug fixes vs. the original:
    - ``start = mid`` / ``end = mid`` without the +/-1 step could loop
      forever once the interval narrowed to two elements (mid stays equal
      to start); now uses ``mid + 1`` / ``mid - 1``.
    - The ``while start != end`` condition skipped testing the final
      candidate element; ``while start <= end`` checks every survivor.
    - An explicit ``return None`` documents the miss case (the original
      fell off the end implicitly).
    """
    start = 0
    end = len(square) - 1
    while start <= end:
        mid = (start + end) // 2
        entry = square[mid]
        if entry[0] <= ip <= entry[1]:
            return entry[2:]  # region info
        if ip < entry[0]:
            end = mid - 1
        else:
            start = mid + 1
    return None
def main():
    """Read an IP-range library and a list of user IPs from HDFS, broadcast
    the parsed library to every executor, and print each user IP together
    with its looked-up region."""
    spark = SparkSession.builder.appName("ip_identify").getOrCreate()
    sc = spark.sparkContext

    # Each library line: "start_ip end_ip region..." (whitespace-separated).
    # textFile already yields one record per line; the split('\n') flatMap is
    # kept as a no-op safeguard for records with embedded newlines.
    ip_lib = sc.textFile("/usr/local/big_data/learn_pyspark/ip.lib.txt")
    ip_lib = ip_lib.flatMap(lambda x: x.split('\n'))
    user_ips = sc.textFile("/usr/local/big_data/learn_pyspark/user_ips.txt")
    user_ips = user_ips.flatMap(lambda x: x.split('\n'))

    # Parse each library row into (start_int, end_int, *region_fields);
    # the library is assumed pre-sorted by range (see module header note).
    ip_lib = ip_lib.map(lambda x: x.split()).map(lambda item:
        (ip_transform(item[0]), ip_transform(item[1]), *item[2:]))

    # Broadcast the collected library so every executor holds one
    # read-only copy instead of shipping it with each task.
    ip_lib_broadcast = sc.broadcast(ip_lib.collect())

    def get_position(ip):
        """Return (ip, *region_fields), or (ip, 'unknown') when the IP
        falls outside every range in the library."""
        region = binary_search(ip_transform(ip), ip_lib_broadcast.value)
        # BUG FIX: the original did `(ip, *temp)` unconditionally and
        # raised TypeError on executors whenever binary_search returned
        # None (IP not covered by any range).
        if region is None:
            return (ip, 'unknown')
        return (ip, *region)

    for record in user_ips.map(get_position).collect():
        print(record)
    sc.stop()


if __name__ == '__main__':
    main()