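Example: Fine-tuning a large model with a shell script
Suppose we have a dataset and want to fine-tune a pretrained ResNet50 model on it for a classification task.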
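1. Prepare the data and environment
Split the dataset into a training set and a validation set, and make sure the paths and file layout match what the model's input pipeline expects. The training script in step 3 reads images with flow_from_directory, which takes class labels from subdirectory names, so data/train and data/val each need one subdirectory per class.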
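# Create the directories for training and validation data
mkdir -p data/train data/val
# Copy the dataset into place (assumes the data is already prepared, one subdirectory per class)
cp -r path_to_train_data/* data/train/
cp -r path_to_val_data/* data/val/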
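If the raw images are not yet split, the train/validation layout described above could be produced with a short Python helper. The following is a minimal sketch, not part of the original workflow: the function name split_dataset and its val_fraction and seed parameters are hypothetical, and it assumes the source directory already contains one subdirectory per class.

```python
import random
import shutil
from pathlib import Path


def split_dataset(src_dir, out_dir, val_fraction=0.2, seed=0):
    """Copy src_dir/<class>/* into out_dir/train/<class> and out_dir/val/<class>.

    Hypothetical helper: assumes src_dir holds one subdirectory per class.
    Returns a dict mapping class name to (n_train, n_val).
    """
    rng = random.Random(seed)  # fixed seed keeps the split reproducible
    counts = {}
    class_dirs = sorted(p for p in Path(src_dir).iterdir() if p.is_dir())
    for class_dir in class_dirs:
        files = sorted(f for f in class_dir.iterdir() if f.is_file())
        rng.shuffle(files)
        n_val = int(len(files) * val_fraction)
        splits = {"val": files[:n_val], "train": files[n_val:]}
        for split_name, split_files in splits.items():
            target = Path(out_dir) / split_name / class_dir.name
            target.mkdir(parents=True, exist_ok=True)
            for f in split_files:
                shutil.copy2(f, target / f.name)
        counts[class_dir.name] = (len(files) - n_val, n_val)
    return counts
```

Calling split_dataset("path_to_raw_data", "data") would then fill data/train and data/val in one step, where path_to_raw_data is a placeholder for wherever the unsplit images live.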
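2. Write the shell script
Create a shell script (for example, train.sh) that sets the training parameters, launches the training run, and specifies where the model is saved.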
#!/bin/bash
# Training parameters
epochs=10
batch_size=32
learning_rate=0.001
model_dir="saved_models/resnet50_finetuned"
# Launch training
python train.py \
  --epochs "$epochs" \
  --batch_size "$batch_size" \
  --learning_rate "$learning_rate" \
  --model_dir "$model_dir" \
  --train_data_dir "data/train" \
  --val_data_dir "data/val"
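When the same launch has to be repeated with several settings, the command that train.sh assembles can also be built programmatically. The sketch below is one possible way to do that, under stated assumptions: build_train_command and run_sweep are hypothetical names, the per-run model directory naming is invented for illustration, and sys.executable stands in for the python command used in the shell script.

```python
import subprocess
import sys


def build_train_command(epochs, batch_size, learning_rate, model_dir,
                        train_data_dir="data/train", val_data_dir="data/val"):
    """Return the same argument list that train.sh passes to train.py."""
    return [
        sys.executable, "train.py",
        "--epochs", str(epochs),
        "--batch_size", str(batch_size),
        "--learning_rate", str(learning_rate),
        "--model_dir", model_dir,
        "--train_data_dir", train_data_dir,
        "--val_data_dir", val_data_dir,
    ]


def run_sweep(learning_rates, dry_run=True):
    """Build (and optionally run) one training command per learning rate."""
    commands = []
    for lr in learning_rates:
        # Hypothetical naming scheme: one output directory per learning rate
        model_dir = f"saved_models/resnet50_lr{lr}"
        cmd = build_train_command(10, 32, lr, model_dir)
        commands.append(cmd)
        if not dry_run:
            # check=True stops the sweep if a training run fails
            subprocess.run(cmd, check=True)
    return commands
```

With dry_run=True the function only returns the commands, which makes it easy to inspect them before starting any long-running job.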
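3. Write the training script (train.py)
The shell script calls a Python script (for example, train.py) that defines the model, runs the training loop, evaluates on the validation set, and saves the model.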
import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator
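
def build_model(num_classes):
    # Load ResNet50 with ImageNet weights, without its original classification head
    base_model = ResNet50(weights='imagenet', include_top=False)
    # Add a new head: global pooling, one hidden layer, and a softmax over our classes
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(1024, activation='relu')(x)
    predictions = Dense(num_classes, activation='softmax')(x)
    model = tf.keras.Model(inputs=base_model.input, outputs=predictions)
    return model
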
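def train_model(train_data_dir, val_data_dir, epochs, batch_size, learning_rate, model_dir):
    # ImageNet-pretrained ResNet50 expects its own preprocessing rather than a plain 1/255 rescale
    preprocess = tf.keras.applications.resnet50.preprocess_input
    train_datagen = ImageDataGenerator(preprocessing_function=preprocess)
    train_generator = train_datagen.flow_from_directory(
        train_data_dir,
        target_size=(224, 224),
        batch_size=batch_size,
        class_mode='categorical'
    )
    val_datagen = ImageDataGenerator(preprocessing_function=preprocess)
    val_generator = val_datagen.flow_from_directory(
        val_data_dir,
        target_size=(224, 224),
        batch_size=batch_size,
        class_mode='categorical'
    )
    # The number of classes is taken from the subdirectories found in train_data_dir
    model = build_model(num_classes=len(train_generator.class_indices))
    # Use learning_rate=...; the older lr= argument is deprecated and rejected by newer Keras versions
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    model.fit(train_generator,
              epochs=epochs,
              validation_data=val_generator)
    # Save the model. Older TF 2.x writes a SavedModel directory here;
    # Keras 3 requires a path ending in .keras or .h5 instead.
    model.save(model_dir)
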
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=10, help='Number of epochs for training')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate for training')
    parser.add_argument('--model_dir', type=str, default='saved_models/resnet50_finetuned', help='Directory to save trained model')
    parser.add_argument('--train_data_dir', type=str, required=True, help='Directory containing training data')
    parser.add_argument('--val_data_dir', type=str, required=True, help='Directory containing validation data')
    args = parser.parse_args()
    train_model(args.train_data_dir, args.val_data_dir, args.epochs, args.batch_size, args.learning_rate, args.model_dir)
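The script above trains every layer of the network. A common variation when fine-tuning is to freeze most of the pretrained layers and train only the last few. The sketch below illustrates that idea and is not part of the original script: freeze_all_but_last is a hypothetical helper, written against any sequence of objects that have a trainable attribute (as the layers in a Keras model's layers list do), so it does not itself depend on TensorFlow.

```python
def freeze_all_but_last(layers, n_trainable):
    """Mark only the last n_trainable layers as trainable and freeze the rest.

    Hypothetical helper: works on any sequence of objects exposing a boolean
    `trainable` attribute. Returns the number of layers that were frozen.
    """
    if n_trainable < 0:
        raise ValueError("n_trainable must be non-negative")
    layers = list(layers)
    # Index of the first layer that stays trainable
    cutoff = max(len(layers) - n_trainable, 0)
    for index, layer in enumerate(layers):
        layer.trainable = index >= cutoff
    return cutoff
```

Assuming the model returned by build_model, calling freeze_all_but_last(model.layers, 3) before model.compile would leave only the pooling layer and the two Dense layers of the new head trainable. Keras applies changes to trainable at compile time, so the call belongs before compile or must be followed by a recompile.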
4. Run the training
Start training by running the shell script:
bash train.sh