DeepForest 树检测
在已安装好 cudatoolkit 11.3.1、对应版本的 PyTorch 以及 Python 3.9 的
环境中，执行：
2022-01-26 22:23
conda install deepforest albumentations -c conda-forge
该命令下载并安装了 deepforest-1.2.0。
使用新数据集重新训练模型（官方给出的 demo）：
#load the modules
import os
import time
import cv2
import numpy as np
from deepforest import main
from deepforest import get_data
from deepforest import utilities
from deepforest import preprocess
from PIL import Image
import rasterio
import matplotlib.pyplot as plt
# --- Step 1: convert the hand-made XML annotations to DeepForest's CSV format ---
# get_data() looks up files shipped with the deepforest package; it is only
# needed here because this demo uses the package's bundled sample data.
YELL_xml = get_data("2019_YELL_2_528000_4978000_image_crop2.xml")
# The image that the XML annotations above describe.
YELL_train = get_data("2019_YELL_2_528000_4978000_image_crop2.png")

annotation = utilities.xml_to_annotations(YELL_xml)
annotation.head()  # only displays something when run in a notebook

# Save the converted table in the same folder as the sample image
# (index=False: the DataFrame index is not written as a column).
image_path = os.path.dirname(YELL_train)
converted_csv = os.path.join(image_path, "train_example.csv")
annotation.to_csv(converted_csv, index=False)
# --- Step 2: cut the image into patches and split them into train / validation ---
# Annotation CSV written in the previous step (same folder as the sample image).
annotation_path = os.path.join(image_path, "train_example.csv")

# The cropped patches are saved into this folder. Create it up front so the
# CSV writes at the end of this step cannot fail on a missing directory.
crop_dir = os.path.join(os.getcwd(), 'train_data_folder')
os.makedirs(crop_dir, exist_ok=True)
train_annotations = preprocess.split_raster(path_to_raster=YELL_train,
                                            annotations_file=annotation_path,
                                            base_dir=crop_dir,
                                            patch_size=400,
                                            patch_overlap=0.05)

# Split by crop (image_path), not by box, so that all boxes of one crop land
# on the same side. In a real project train and validation would come from
# different tiles; this is only an example.
image_paths = train_annotations.image_path.unique()
# Hold out 25% of the crops for validation. replace=False is required: with
# numpy's default (replace=True) the same crop can be drawn several times, and
# fewer than 25% of the crops would actually end up in validation.
n_valid = int(len(image_paths) * 0.25)
valid_paths = np.random.choice(image_paths, n_valid, replace=False)
valid_annotations = train_annotations.loc[train_annotations.image_path.isin(valid_paths)]
train_annotations = train_annotations.loc[~train_annotations.image_path.isin(valid_paths)]

# View output (head() only displays something in a notebook)
train_annotations.head()
print("There are {} training crown annotations".format(train_annotations.shape[0]))
print("There are {} test crown annotations".format(valid_annotations.shape[0]))

# Write both annotation files into crop_dir. The training/validation root_dir
# is later set to this folder, so the image_path entries are presumably
# resolved relative to it. pandas writes a header row; index=False only drops
# the DataFrame index column.
annotations_file = os.path.join(crop_dir, "train.csv")
validation_file = os.path.join(crop_dir, "valid.csv")
train_annotations.to_csv(annotations_file, index=False)
valid_annotations.to_csv(validation_file, index=False)
# print(annotations_file)
# --- Step 3: create the model, adjust its config and train ---
m = main.deepforest()
m.config['gpus'] = '-1'  # move to GPU and use all available GPUs
m.config["train"]["csv_file"] = annotations_file
m.config["train"]["root_dir"] = os.path.dirname(annotations_file)
m.config["score_thresh"] = 0.4
m.config["train"]['epochs'] = 2
m.config["validation"]["csv_file"] = validation_file
m.config["validation"]["root_dir"] = os.path.dirname(validation_file)
# deepforest 1.2.0 reads config["train"]["preload_images"] when it builds the
# data loaders, and fails with KeyError: 'preload_images' if the loaded config
# lacks that key. setdefault leaves any value the installed version already
# provides untouched. NOTE(review): False (do not preload images) is assumed
# to match the default of later releases - confirm.
m.config["train"].setdefault("preload_images", False)
# Create the PyTorch Lightning trainer only after the config edits above,
# since it may read settings such as the epoch count from the config.
m.create_trainer()
# Load the latest release model; training then continues from it.
m.use_release()
print("data ok")
start_time = time.time()
m.trainer.fit(m)
print(f"--- Training on GPU: {(time.time() - start_time):.2f} seconds ---")
print()
print(annotations_file)
print()
# annotations_file='/home/pikapikaq/Desktop/TreeHeight/test/1.png'
# annotations_file='/home/pikapikaq/anaconda3/envs/workspace/lib/python3.9/site-packages/deepforest/data/OSBS_029.csv'

# --- Step 4: evaluate and save the prediction results into ./pred_result ---
# makedirs(exist_ok=True) replaces the try/mkdir/except-pass pattern. It
# tolerates an existing folder but still raises if a *file* of that name is
# in the way, which the old code silently ignored.
save_dir = os.path.join(os.getcwd(), 'pred_result')
os.makedirs(save_dir, exist_ok=True)
# NOTE(review): this evaluates on the training crops (annotations_file), so
# the scores are optimistic; pass validation_file and its folder instead to
# score on crops the model was not fitted on.
results = m.evaluate(annotations_file, os.path.dirname(annotations_file),
                     iou_threshold=0.4, savedir=save_dir)
# csv_file = '/home/pikapikaq/Desktop/TreeHeight/DeepTree/train_data_folder/2019_YELL_2_528000_4978000_image_crop2_5.png'#'/home/pikapikaq/Desktop/TreeHeight/DeepTree/data/1.jpeg' #'/home/pikapikaq/anaconda3/envs/workspace/lib/python3.9/site-packages/deepforest/data/OSBS_029.tif'
# img=cv2.imread(csv_file)
# img = img.astype(np.float32)/255
# #print(img.dtype)
# # print(img)
# df = m.predict_image(image=img,return_plot=True)#root_dir = os.path.dirname(csv_file))
# print(df)
# csv_file = '/home/pikapikaq/anaconda3/envs/workspace/lib/python3.9/site-packages/deepforest/data/OSBS_029.csv'
# df = m.predict_file(csv_file, root_dir = os.path.dirname(csv_file))
# print(df)
# csv_file = '/home/pikapikaq/Desktop/TreeHeight/DeepTree/pic/OSBS_029.tif' #'/home/pikapikaq/Desktop/TreeHeight/DeepTree/data/1.jpeg'
# img=Image.open(csv_file)
# img_arr = np.array(img)
# img_arr = img_arr.astype(np.float32)/255
# print(img_arr.shape) # uint8
# print(img_arr.dtype)
# df = m.predict_image(image=img_arr,return_plot=True) #
# print(df)
# raster = '/home/pikapikaq/Desktop/TreeHeight/DeepTree/data/1.jpeg'# get_data("2019_YELL_2_528000_4978000_image_crop2.png")
# # '/home/pikapikaq/Desktop/TreeHeight/DeepTree/data/1.jpeg'
# src = rasterio.open(raster)
# #s = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)
# print(src.read().shape) # #(3, 2472, 2299)
# predicted_boxes =m.predict_tile(raster_path = raster,
# patch_size = 300,
# patch_overlap = 0.5,
# return_plot = True)
# plt.imshow(predicted_boxes[:,:,::-1])
# plt.show()
# print(predicted_boxes)
运行 demo 时出现如下报错（下面的日志取自一个 Python 3.7 的 notebook 环境，与上文的环境略有不同）：
KeyError Traceback (most recent call last)
[<ipython-input-42-0501e4367b85>](https://localhost:8080/#) in <module>()
1 start_time = time.time()
2
----> 3 m.trainer.fit(m)
4
5 print(f"--- Training on CPU: {(time.time() - start_time):.2f} seconds ---")
16 frames
[/usr/local/lib/python3.7/dist-packages/deepforest/main.py](https://localhost:8080/#) in load_dataset(self, csv_file, root_dir, augment, shuffle, batch_size, train)
167 transforms=self.transforms(augment=augment),
168 label_dict=self.label_dict,
--> 169 preload_images=self.config["train"]["preload_images"])
170
171 data_loader = torch.utils.data.DataLoader(
KeyError: 'preload_images'
在讨论区了解到，需要将 DeepForest 升级到最新版本，执行：
pip install deepforest --upgrade
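Version 1.2.1 no longer raises this error. (In 1.2.0, `load_dataset` reads
`config["train"]["preload_images"]`, a key that is missing from the loaded config, hence the `KeyError`.)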