from datasets import Dataset, Features, Array3D, Value
import os
import numpy as np
from PIL import Image
# Define image path
image_dir = "/path/to/local_img_dir"


# Load and convert images to numpy arrays with resizing
def load_image(image_path):
    """Load an image file and return it as a 512x512 RGB numpy array.

    Args:
        image_path: Path of the image file to open with PIL.

    Returns:
        A numpy array of shape (512, 512, 3) built from the resized image.
    """
    # The context manager closes the underlying file once the pixel data has
    # been copied into the returned array.
    with Image.open(image_path) as image:
        # Force three channels: PNG files may be RGBA, palette, or grayscale,
        # which would otherwise yield arrays that are not (512, 512, 3).
        resized = image.convert("RGB").resize((512, 512))  # Resize to 512x512
        return np.array(resized)


# Generate image file paths list
# Sorted so the row order of the dataset is reproducible (os.listdir returns
# entries in arbitrary order); the extension check ignores case.
image_paths = sorted(
    os.path.join(image_dir, name)
    for name in os.listdir(image_dir)
    if name.lower().endswith(".png")
)


# Create dataset, returning a dictionary with keys 'image' and 'filename'
def generate_dataset():
    """Yield one example per entry of ``image_paths``.

    Yields:
        dict: ``"image"`` holds the array returned by ``load_image`` and
        ``"filename"`` holds the base name of the source file (the extension
        is kept).
    """
    for path in image_paths:
        yield {"image": load_image(path), "filename": os.path.basename(path)}
# Set batch size to avoid memory overflow.
# writer_batch_size is a dataset-builder option (rows per write to the cache
# file), so it is passed to from_generator; with_format only accepts
# output-formatting options.
dataset = Dataset.from_generator(
    generate_dataset,
    features=Features(
        {
            "image": Array3D(dtype="uint8", shape=(512, 512, 3)),  # Updated shape
            "filename": Value("string"),
        }
    ),
    writer_batch_size=50,  # Adjust as necessary
)
dataset = dataset.with_format("numpy")

# Push dataset to Hugging Face Hub
dataset.push_to_hub("your/data_repo")
from datasets import load_dataset
from PIL import Image
import numpy as np
import os
# If the repo is private, a specific token is needed for authentication.
# NOTE: this is a placeholder; avoid committing a real token to source control.
HF_TOKEN = "hf_yourtoken"

# Load the dataset from Hugging Face Hub
dataset = load_dataset("your/data_repo", token=HF_TOKEN)

# Directory where the images will be saved
save_dir = "/local/dir"

# Ensure the directory exists
os.makedirs(save_dir, exist_ok=True)


# Function to convert numpy array to PIL image
def convert_to_image(image_array):
    """Convert pixel data into a PIL image.

    Args:
        image_array: Pixel values for one image; they are cast with
            ``np.uint8`` before being handed to PIL.

    Returns:
        The image created by ``Image.fromarray`` from the uint8 data.
    """
    return Image.fromarray(np.uint8(image_array))


# For testing: save only the first 10 images to take a look. Note that
# load_dataset still downloads all of the data (.parquet) up front.
train_split = dataset['train']
# Clamp to the split size so a split with fewer than 10 rows does not raise.
preview_count = min(10, len(train_split))
for i, example in enumerate(train_split.select(range(preview_count))):
    image_array = example['image']  # Get the pixel data for the image
    image = convert_to_image(image_array)  # Convert to PIL image
    filename = example['filename']  # Get the stored filename
    # Append ".png" only when the stored name lacks it, so that a name such as
    # "a.png" is not written out as "a.png.png".
    if not filename.lower().endswith(".png"):
        filename = f"{filename}.png"
    # Construct the full file path and save the image as a PNG
    image_path = os.path.join(save_dir, filename)
    image.save(image_path)
    print(f"Saved {filename} to {save_dir}")