# 学习如何对AIS数据集整体概况进行分析,掌握船舶轨迹数据集的基本情况(缺失值、异常值)
# 学习了解数据变量之间的相互关系、变量与预测值之间存在的关系。
import warnings
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
from tqdm import tqdm
import multiprocessing as mp
import os
import pickle
import random
# 把读取所有数据的函数放在单独的python文件中,是为了解决多线程问题在jupyter notebook无法运行的问题
import read_all_data
class Load_Save_Data():
    """Small helper for pickling Python objects to disk and reading them back.

    A default file path can be supplied at construction; both methods accept
    an explicit path that overrides (and replaces) the stored default.
    """

    def __init__(self, file_name=None):
        # Default pickle path used when load/save is called without a path.
        self.filename = file_name

    def load_data(self, Path=None):
        """Unpickle and return the object stored at ``Path`` (or the default).

        Raises AssertionError when neither a path argument nor a stored
        default filename is available.
        """
        if Path is None:
            assert self.filename is not None,"Invalid Path...."
        else:
            self.filename = Path
        # BUG FIX: must open for binary *reading* ("rb"); the original "wb"
        # truncated the file on open, destroying the data before pickle.load.
        with open(self.filename, "rb") as f:
            data = pickle.load(f)
        return data

    def save_data(self, data, path):
        """Pickle ``data`` to ``path`` (or to the stored default when None)."""
        if path is None:
            assert self.filename is not None,"Invalid path...."
        else:
            self.filename = path
        with open(self.filename, "wb") as f:
            pickle.dump(data, f)
# 定义读取数据的函数
def read_data(Path, Kind=""):
    """Read every AIS data file under ``Path`` in parallel worker processes.

    :param Path: directory holding the raw per-vessel data files
    :param Kind: 'train' or 'test' -- selects the per-file reader; for
        training data the combined result is also cached to disk.
    :return: list with one parsed object (DataFrame) per input file
    """
    filenames = os.listdir(Path)
    print("\n@Read Data From"+Path+".........................")
    # Choose the reader once instead of re-evaluating the conditional per map call.
    reader = read_all_data.read_train_file if Kind == "train" else read_all_data.read_test_file
    with mp.Pool(processes=mp.cpu_count()) as pool:
        # BUG FIX: the original expression was missing two closing parentheses.
        # imap (rather than the blocking map) lets tqdm show real progress.
        data_total = list(tqdm(pool.imap(reader, filenames), total=len(filenames)))
    print("\n@End Read total Data............................")
    # Cache the training set so later steps can reload it without re-parsing.
    if Kind == "train":
        load_save = Load_Save_Data()
        load_save.save_data(data_total, "./data_tmp/total_data.pkl")
    return data_total
# 训练数据读取
# Absolute path of the raw training data (machine-specific; adjust as needed).
train_path = "D:/code_sea/data/train/hy_round1_train_20200102/"
# read_data returns one DataFrame per vessel file; concat merges them into one frame.
data_train = read_data(train_path,Kind="train")
data_train = pd.concat(data_train)
# Test-set loading
# Absolute path of the raw test data (machine-specific; adjust as needed).
test_path = "D:/code_sea/data/test/hy_round1_testA_20200102/"
data_test = read_data(test_path,Kind="test")
data_test = pd.concat(data_test)
# Raise pandas' info() row threshold so null counts are computed for the
# full ~2.7M-row training frame instead of being skipped.
pd.options.display.max_info_rows = 2699639
data_train.info()
# Spread of each numeric column; tail percentiles help spot outliers.
data_train.describe([0.01,0.025,0.05,0.5,0.75,0.9,0.99])
# BUG FIX: DataFrame.append was deprecated and removed in pandas 2.0;
# build the head/tail preview with pd.concat instead.
pd.concat([data_train.head(3), data_train.tail(3)])
print(f'There are {data_train.isnull().any().sum()} columns in train dataset with missing values.')
# 把训练集的所有数据,根据类别存放到不同的数据文件中
def get_diff_data(data_dir="./data_tmp/"):
    """Split the cached training data into one pickle file per gear type.

    :param data_dir: directory containing ``total_data.pkl`` and receiving
        the three per-class output files (default preserves the original
        hard-coded location).
    """
    with open(data_dir + "total_data.pkl", "rb") as f:
        total_data = pickle.load(f)
    load_save = Load_Save_Data()
    # Chinese gear-type labels paired with their output file names.
    kind_data = ["刺网","围网","拖网"]
    file_names = ["ciwang_data.pkl","weiwang_data.pkl","tuowang_data.pkl"]
    for label, fname in zip(kind_data, file_names):
        # Keep only trajectories whose (assumed single) type matches this label.
        data_type = [data for data in total_data if data["type"].unique()[0] == label]
        load_save.save_data(data_type, data_dir + fname)
def get_random_one_traj(type=None):
    """Load the pickled trajectory list for one gear class and return one entry.

    :param type: file stem, e.g. "ciwang_data", "weiwang_data" or "tuowang_data"
    :return: a single trajectory from the pickled list

    NOTE(review): the seed is fixed, so every call returns the same index;
    presumably intentional for reproducible plots -- confirm.
    """
    # Fixed seed keeps the "random" pick reproducible across runs.
    np.random.seed(10)
    pkl_file = "./data_tmp/" + type + ".pkl"
    with open(pkl_file, "rb") as f1:
        data = pickle.load(f1)
    # Draw one index uniformly from range(len(data)).
    chosen = np.random.choice(len(data))
    return data[chosen]
# 3x3 grid of trajectories: one row per gear class, three trajectories per row.
def visualize_three_traj():
    """Plot three trajectories for each of the three gear classes.

    Start points are drawn as red octagons, end points as green triangles.
    NOTE(review): get_random_one_traj uses a fixed seed, so the three
    trajectories per class may be identical -- confirm the helper's intent.
    """
    fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(20, 15))
    plt.subplots_adjust(wspace=0.2, hspace=0.2)
    lables = ["ciwang", "weiwang", "tuowang"]
    for i, file_type in tqdm(enumerate(["ciwang_data", "weiwang_data", "tuowang_data"])):
        # BUG FIX: the original called an undefined get_random_three_traj();
        # draw three trajectories through the existing single-trajectory helper.
        trajs = [get_random_one_traj(type=file_type) for _ in range(3)]
        for j, datax in enumerate(trajs):
            x_data = datax["x"].loc[-1:].values
            y_data = datax["y"].loc[-1:].values
            # BUG FIX: index columns as [j] (0..2); the original [j-1]
            # silently wrapped column 0 around to the last column.
            axes[i][j].scatter(x_data[0], y_data[0], label="start", c="red", s=10, marker="8")
            axes[i][j].plot(x_data, y_data, label=lables[i])
            # BUG FIX: the end-marker call was truncated in the source;
            # restored as a complete scatter call with an explicit colour.
            axes[i][j].scatter(x_data[-1], y_data[-1], label="end", c="green", s=10, marker="v")
            # BUG FIX: alpha must lie in [0, 1]; the original value 2 is invalid.
            axes[i][j].grid(alpha=0.5)
            axes[i][j].legend(loc="best")
    plt.show()
visualize_three_traj()
# Pick one trajectory and plot its x and y coordinate series over sample index.
def visualize_one_traj_x_y():
    """Plot the x and y coordinate sequences of one "weiwang" trajectory.

    Coordinates are scaled down by 1e4 purely for readable axis labels.
    """
    fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(10, 8))
    plt.subplots_adjust(wspace=0.5, hspace=0.5)
    data1 = get_random_one_traj(type="weiwang_data")
    x = data1["x"].loc[-1:] / 10000
    y = data1["y"].loc[-1:] / 10000
    arr1 = np.arange(len(x))
    arr2 = np.arange(len(y))
    axes[0].plot(arr1, x, label="x")
    axes[1].plot(arr2, y, label="y")
    for ax in axes:
        # BUG FIX: alpha must lie in [0, 1]; the original value 3 raises a
        # ValueError in current matplotlib versions.
        ax.grid(alpha=0.5)
        ax.legend(loc="best")
    plt.show()
visualize_one_traj_x_y()