1. Create the environment: conda create --name a310 python=3.10.15
(base) root@VM-5-128-ubuntu:/workspace# conda activate a310
2. Install pyenv:
curl https://pyenv.run | bash
3. Edit .bash_profile (the file is under /root) and add:
export PYENV_ROOT="$HOME/.pyenv"
[[ -d $PYENV_ROOT/bin ]] && export PATH="$PYENV_ROOT/bin:$PATH"
eval "$(pyenv init -)"
export PATH="/root/.local/bin:$PATH"
4. Reload the environment variables; the changes only take effect after you run:
source ~/.bash_profile
5. Check that pyenv works:
(a310) root@VM-5-128-ubuntu:/workspace# pyenv --version
pyenv 2.5.0
6. Install poetry:
(a310) root@VM-5-128-ubuntu:/workspace# pip install poetry
(a310) root@VM-5-128-ubuntu:/workspace# poetry --version
Poetry (version 1.8.5)
7. Install Python 3.10.15 with pyenv: pyenv install 3.10.15
The build failed several times along the way with errors like:
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['/tmp/python-build.20241229151142.3168/Python-3.10.15/python', '-W', 'ignore::DeprecationWarning', '-c', '\nimport runpy\nimport sys\nsys.path = [\'/tmp/tmp8oem7l6p/setuptools-65.5.0-py3-none-any.whl\', \'/tmp/tmp8oem7l6p/pip-23.0.1-py3-none-any.whl\'] + sys.path\nsys.argv[1:] = [\'install\', \'--no-cache-dir\', \'--no-index\', \'--find-links\', \'/tmp/tmp8oem7l6p\', \'--root\', \'/\', \'--upgrade\', \'setuptools\', \'pip\']\nrunpy.run_module("pip", run_name="__main__", alter_sys=True)\n']' returned non-zero exit status 1.
make: *** [Makefile:1280: install] Error 1
Error analysis: the zlib module was missing during the CPython build.
First, install sudo:
apt-get install sudo
Manually install pip and setuptools:
wget https://bootstrap.pypa.io/get-pip.py
python get-pip.py
pip install setuptools
Install the zlib development library:
sudo apt-get update
sudo apt-get install zlib1g-dev
Run pyenv install 3.10.15 again, then check with pyenv versions; this command lists all installed Python versions, and 3.10.15 should appear in the list.
It still had not installed correctly.
执行这两个代码
sudo apt-get update
sudo apt-get install -y make build-essential libssl-dev zlib1g-dev libbz2-dev \
  libreadline-dev libsqlite3-dev wget curl llvm libncurses5-dev libncursesw5-dev \
  xz-utils tk-dev libffi-dev liblzma-dev python-openssl
(On newer Ubuntu releases the python-openssl package no longer exists; substitute python3-openssl or drop it.)
After the dependencies are installed, use pyenv to install Python 3.10.15 once more:
pyenv install 3.10.15
Check with pyenv versions:
(base) root@VM-5-128-ubuntu:/workspace/timesfm# pyenv versions
* system (set by /root/.pyenv/version)
3.10.15
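As a quick sanity check that the rebuilt interpreter actually linked against those C libraries, you can try importing the stdlib modules each dev package provides. A minimal sketch (the file name check_build.py is made up for illustration; run it with ~/.pyenv/versions/3.10.15/bin/python check_build.py):

# check_build.py -- verify the pyenv-built interpreter picked up the dev libraries
import importlib

# each stdlib module below is backed by one of the apt packages installed above
MODULES = {
    "zlib": "zlib1g-dev",
    "ssl": "libssl-dev",
    "bz2": "libbz2-dev",
    "lzma": "liblzma-dev",
    "sqlite3": "libsqlite3-dev",
    "readline": "libreadline-dev",
    "ctypes": "libffi-dev",
}

for module, package in MODULES.items():
    try:
        importlib.import_module(module)
        print(f"OK   {module} ({package})")
    except ImportError as exc:
        print(f"FAIL {module} -- reinstall {package} and rebuild: {exc}")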
Then install Python 3.11 as well, and run pyenv versions again:
(base) root@VM-5-128-ubuntu:/workspace/timesfm# pyenv versions
* system (set by /root/.pyenv/version)
3.10.15
3.11.11
poetry env use 3.10.15
(base) root@VM-5-128-ubuntu:/workspace/timesfm# poetry env use 3.10.15
Creating virtualenv timesfm-p1AFFT58-py3.10 in /root/.cache/pypoetry/virtualenvs
Using virtualenv: /root/.cache/pypoetry/virtualenvs/timesfm-p1AFFT58-py3.10
poetry lock
poetry install -E pax
(base) root@VM-5-128-ubuntu:/workspace/timesfm# poetry install -E pax
Installing dependencies from lock file
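Before launching anything, it's worth confirming that the pax extra actually pulled in the JAX stack and that JAX sees the GPU. A minimal check (run as poetry run python3 check_env.py; the file name is hypothetical):

# check_env.py -- confirm the `pax` extra installed the JAX stack and a GPU is visible
import jax
import paxml   # pulled in by the pax extra
import praxis  # pulled in by the pax extra

print("jax", jax.__version__)
# Expect a CUDA device here; a bare CPU device means the CUDA-enabled
# jaxlib wheel was not installed.
print(jax.devices())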
After that you can run timesfm under `poetry shell` or do `poetry run python3 ...`
(base) root@VM-5-128-ubuntu:/workspace/timesfm# poetry run python3 -m vehicle_fcst --model_path=google/timesfm-1.0-200m
The code of vehicle_fcst.py is as follows:
import gc
import os

import jax
import numpy as np
import pandas as pd
from jax import numpy as jnp
from tqdm import tqdm

import timesfm
from timesfm import TimesFm, TimesFmHparams, TimesFmCheckpoint, patched_decoder, data_loader
from praxis import pax_fiddle, py_utils, pytypes, base_model, optimizers, schedules, base_hyperparams, base_layer
from paxml import tasks_lib, trainer_lib, checkpoints, learners, partitioning, checkpoint_types

# Set environment variables: don't let JAX preallocate most of the GPU memory,
# and don't route pmap checkpoint restores through TensorStore.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
os.environ['JAX_PMAP_USE_TENSORSTORE'] = 'false'
def load_timesfm_model():
    """Load the pretrained TimesFM model."""
    tfm = TimesFm(
        hparams=TimesFmHparams(
            backend="gpu",
            per_core_batch_size=32,
            horizon_len=128,
        ),
        checkpoint=TimesFmCheckpoint(
            # or "google/timesfm-1.0-200m" for the smaller checkpoint
            huggingface_repo_id="google/timesfm-2.0-500m-jax"),
    )
    return tfm
def load_dataset(dataset_name):
    """Load the dataset and build train/val/test batches."""
    DATA_DICT = {
        "vehicle_sales": {
            "boundaries": [85991, 108028, 126414],
            "data_path": "datasets/output 2024.csv",
            "freq": "D",
        }
    }
    data_path = DATA_DICT[dataset_name]["data_path"]
    freq = DATA_DICT[dataset_name]["freq"]
    int_freq = timesfm.freq_map(freq)
    boundaries = DATA_DICT[dataset_name]["boundaries"]
    data_df = pd.read_csv(data_path)  # loaded for inspection only; the loader below reads the file itself
    ts_cols = ["sales_volume", "sales_price"]
    num_cov_cols = None
    cat_cov_cols = None
    context_len = 512
    pred_len = 96
    num_ts = len(ts_cols)
    batch_size = 4
    dtl = data_loader.TimeSeriesdata(
        data_path=data_path,
        datetime_col="date",
        num_cov_cols=num_cov_cols,
        cat_cov_cols=cat_cov_cols,
        ts_cols=np.array(ts_cols),
        train_range=[0, boundaries[0]],
        val_range=[boundaries[0], boundaries[1]],
        test_range=[boundaries[1], boundaries[2]],
        hist_len=context_len,
        pred_len=pred_len,
        batch_size=num_ts,
        freq=freq,
        normalize=True,
        epoch_len=None,
        holiday=False,
        permute=True,
    )
    train_batches = dtl.tf_dataset(mode="train", shift=1).batch(batch_size)
    val_batches = dtl.tf_dataset(mode="val", shift=pred_len)
    test_batches = dtl.tf_dataset(mode="test", shift=pred_len)
    return dtl, train_batches, val_batches, test_batches, batch_size, num_ts
def evaluate_pretrained_model(tfm, test_batches):
    """Evaluate the pretrained model (MAE on the test split)."""
    mae_losses = []
    for batch in tqdm(test_batches.as_numpy_iterator()):
        past = batch[0]
        actuals = batch[3]
        forecasts, _ = tfm.forecast(list(past), [0] * past.shape[0], normalize=True)
        forecasts = forecasts[:, 0:actuals.shape[1]]
        mae_losses.append(np.abs(forecasts - actuals).mean())
    return np.mean(mae_losses)
def build_learner():
    """Build the learner: Adam with a cosine schedule, transformer body frozen."""
    return pax_fiddle.Config(
        learners.Learner,
        name='learner',
        loss_name='avg_qloss',
        optimizer=optimizers.Adam(
            epsilon=1e-7,
            clip_threshold=1e2,
            learning_rate=1e-2,
            lr_schedule=pax_fiddle.Config(
                schedules.Cosine,
                initial_value=1e-3,
                final_value=1e-4,
                total_steps=40000,
            ),
            ema_decay=0.9999,
        ),
        # Exclude the stacked transformer layers from backprop, so only the
        # input/output layers are updated during finetuning.
        bprop_variable_exclusion=['.*/stacked_transformer_layer/.*'],
    )
def setup_training_task(model, learner):
    """Set up the training task."""
    task_p = tasks_lib.SingleTask(
        name='ts-learn',
        model=model,
        train=tasks_lib.SingleTask.Train(
            learner=learner,
        ),
    )
    task_p.model.ici_mesh_shape = [1, 1, 1]
    task_p.model.mesh_axis_names = ['replica', 'data', 'mdl']
    return task_p
def initialize_model_state(task_p, key, tbatch, batch_size, num_ts):
    """Initialize the model state."""
    key, init_key = jax.random.split(key)
    return trainer_lib.initialize_model_state(
        task_p,
        init_key,
        process_train_batch(tbatch, batch_size, num_ts),
        checkpoint_type=checkpoint_types.CheckpointType.GDA,
    )

def process_train_batch(batch, batch_size, num_ts):
    """Flatten a training batch into the NestedMap the model expects."""
    past_ts = batch[0].reshape(batch_size * num_ts, -1)
    actual_ts = batch[3].reshape(batch_size * num_ts, -1)
    return py_utils.NestedMap(input_ts=past_ts, actual_ts=actual_ts)

def process_eval_batch(batch):
    """Wrap an eval batch into the NestedMap the model expects."""
    past_ts = batch[0]
    actual_ts = batch[3]
    return py_utils.NestedMap(input_ts=past_ts, actual_ts=actual_ts)
def reshape_batch_for_pmap(batch, num_devices):
    """Add a leading device axis so the batch can be fed to pmap."""
    def _reshape(input_tensor):
        bsize = input_tensor.shape[0]
        residual_shape = list(input_tensor.shape[1:])
        nbsize = bsize // num_devices
        return jnp.reshape(input_tensor, [num_devices, nbsize] + residual_shape)
    return jax.tree.map(_reshape, batch)
def train_model(task_p, replicated_jax_states, train_prng_seed, eval_prng_seed,
                train_batches, val_batches, num_devices, batch_size, num_ts,
                CHECKPOINT_DIR, NUM_EPOCHS=100, PATIENCE=5,
                TRAIN_STEPS_PER_EVAL=1000):
    """Train the model, evaluating every TRAIN_STEPS_PER_EVAL steps with early stopping."""
    p_train_step = jax.pmap(
        lambda states, prng_key, inputs: trainer_lib.train_step_single_learner(
            task_p, states, prng_key, inputs),
        axis_name='batch')
    p_eval_step = jax.pmap(
        lambda states, prng_key, inputs: trainer_lib.eval_step_single_learner(
            task_p, states, prng_key, inputs),
        axis_name='batch')
    best_eval_loss = 1e7
    step_count = 0
    patience = 0
    for epoch in range(NUM_EPOCHS):
        print(f"__________________Epoch: {epoch}__________________", flush=True)
        train_its = train_batches.as_numpy_iterator()
        if patience >= PATIENCE:
            print("Early stopping.", flush=True)
            break
        for batch in tqdm(train_its):
            train_losses = []
            if patience >= PATIENCE:
                print("Early stopping.", flush=True)
                break
            tbatch = process_train_batch(batch, batch_size, num_ts)
            tbatch = reshape_batch_for_pmap(tbatch, num_devices)
            replicated_jax_states, step_fun_out = p_train_step(
                replicated_jax_states, train_prng_seed, tbatch
            )
            train_losses.append(step_fun_out.loss[0])
            if step_count % TRAIN_STEPS_PER_EVAL == 0:
                print(
                    f"Train loss at step {step_count}: {np.mean(train_losses)}",
                    flush=True
                )
                train_losses = []
                print("Starting eval.", flush=True)
                val_its = val_batches.as_numpy_iterator()
                eval_losses = []
                for ev_batch in tqdm(val_its):
                    ebatch = process_eval_batch(ev_batch)
                    ebatch = reshape_batch_for_pmap(ebatch, num_devices)
                    _, step_fun_out = p_eval_step(
                        replicated_jax_states, eval_prng_seed, ebatch
                    )
                    eval_losses.append(step_fun_out.loss[0])
                mean_loss = np.mean(eval_losses)
                print(f"Eval loss at step {step_count}: {mean_loss}", flush=True)
                if mean_loss < best_eval_loss or np.isnan(mean_loss):
                    best_eval_loss = mean_loss
                    print("Saving checkpoint.")
                    jax_state_for_saving = py_utils.maybe_unreplicate_for_fully_replicated(
                        replicated_jax_states
                    )
                    checkpoints.save_checkpoint(
                        jax_state_for_saving, CHECKPOINT_DIR, overwrite=True
                    )
                    patience = 0
                    del jax_state_for_saving
                    gc.collect()
                else:
                    patience += 1
                    print(f"patience: {patience}")
            step_count += 1
def evaluate_finetuned_model(tfm, test_batches):
    """Evaluate the finetuned model (MAE of the median quantile forecast)."""
    mae_losses = []
    for batch in tqdm(test_batches.as_numpy_iterator()):
        past = batch[0]
        actuals = batch[3]
        _, forecasts = tfm.forecast(list(past), [0] * past.shape[0])
        if forecasts.ndim > 2:  # make sure there are enough dimensions to index
            forecasts = forecasts[:, 0:actuals.shape[1], 5]  # index 5 = median quantile
        else:
            print("Forecasts tensor does not have enough dimensions for the specified index.")
        mae_losses.append(np.abs(forecasts - actuals).mean())
    return np.mean(mae_losses)
if __name__ == "__main__":
    tfm = load_timesfm_model()
    dataset_name = "vehicle_sales"
    dtl, train_batches, val_batches, test_batches, batch_size, num_ts = load_dataset(dataset_name)

    # Evaluate the pretrained model
    mae_pretrained = evaluate_pretrained_model(tfm, test_batches)
    print(f"MAE of pretrained model: {mae_pretrained}")

    model = pax_fiddle.Config(
        patched_decoder.PatchedDecoderFinetuneModel,
        name='patched_decoder_finetune',
        core_layer_tpl=tfm.model_p,
    )
    learner = build_learner()
    task_p = setup_training_task(model, learner)
    DEVICES = np.array(jax.devices()).reshape([1, 1, 1])
    MESH = jax.sharding.Mesh(DEVICES, ['replica', 'data', 'mdl'])
    num_devices = jax.local_device_count()
    tbatch = next(train_batches.as_numpy_iterator())
    jax_model_states, _ = initialize_model_state(
        task_p, jax.random.PRNGKey(seed=1234), tbatch, batch_size, num_ts)
    # Warm-start the core layer from the pretrained TimesFM weights
    jax_model_states.mdl_vars['params']['core_layer'] = tfm._train_state.mdl_vars['params']
    jax_vars = jax_model_states.mdl_vars
    gc.collect()
    replicated_jax_states = trainer_lib.replicate_model_state(jax_model_states)
    train_prng_seed = jax.random.split(jax.random.PRNGKey(seed=1234), num=jax.local_device_count())
    eval_prng_seed = jax.random.split(jax.random.PRNGKey(seed=1234), num=jax.local_device_count())
    CHECKPOINT_DIR = '/home/senrajat_google_com/ettm1_finetune'  # change to a path on your own machine
    train_model(task_p, replicated_jax_states, train_prng_seed, eval_prng_seed,
                train_batches, val_batches, num_devices, batch_size, num_ts,
                CHECKPOINT_DIR)
    # Restore the best checkpoint and load it back into TimesFM
    train_state = checkpoints.restore_checkpoint(jax_model_states, CHECKPOINT_DIR)
    print(train_state.step)
    tfm._train_state.mdl_vars['params'] = train_state.mdl_vars['params']['core_layer']
    tfm.jit_decode()
    mae_finetuned = evaluate_finetuned_model(tfm, test_batches)
    print(f"MAE of finetuned model: {mae_finetuned}")
    reduction_percentage = ((mae_pretrained - mae_finetuned) / mae_pretrained) * 100
    print(f"There is around a {reduction_percentage:.2f}% reduction in MAE from finetuning.")
output 2024.csv is a daily CSV with a date column plus sales_volume and sales_price (its contents are not reproduced here).
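The loader config pins down the schema, so a purely hypothetical stand-in with random values could be generated like this:

# hypothetical stand-in for datasets/output 2024.csv -- random values, but the
# schema (date, sales_volume, sales_price) matches what data_loader.TimeSeriesdata
# is configured to read above
import os
import numpy as np
import pandas as pd

n_rows = 126_414 + 96  # cover the last boundary plus one pred_len window
df = pd.DataFrame({
    "date": pd.date_range("1800-01-01", periods=n_rows, freq="D"),
    "sales_volume": np.random.rand(n_rows),
    "sales_price": np.random.rand(n_rows),
})
os.makedirs("datasets", exist_ok=True)
df.to_csv("datasets/output 2024.csv", index=False)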
The training run went as follows:
(timesfm-py3.10) (base) root@VM-0-170-ubuntu:/workspace/timesfm# poetry run python3 -m test --model_path=google/timesfm-1.0-20
TimesFM v1.2.0. See https://github.com/google-research/timesfm/blob/master/README.md for updated APIs.
Loaded Jax TimesFM.
Fetching 6 files: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:00<00:00, 67468.70it/s]
Multiprocessing context has already been set.
Constructing model weights.
Constructed model weights in 3.09 seconds.
Restoring checkpoint from /root/.cache/huggingface/hub/models--google--timesfm-2.0-500m-jax/snapshots/47dedfcadf2abace1cc96071ddb798cfcd3bfcef/checkpoints.
WARNING:absl:No registered CheckpointArgs found for handler type: <class 'paxml.checkpoints.FlaxCheckpointHandler'>
WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by May 1st, 2024.
WARNING:absl:train_state_unpadded_shape_dtype_struct is not provided. We assume `train_state` is unpadded.
Restored checkpoint in 3.37 seconds.
Jitting decoding.
Jitted decoding in 21.43 seconds.
191it [00:09, 19.43it/s]
MAE of pretrained model: 0.8016693453744966
The code runs, and MAE of pretrained model: 0.80.
That was forecasting with the google/timesfm-2.0-500m-jax model. Switching the checkpoint to the smaller model:
checkpoint=TimesFmCheckpoint(huggingface_repo_id="google/timesfm-1.0-200m"),
Restored checkpoint in 1.43 seconds.
Jitting decoding.
Jitted decoding in 21.61 seconds.
191it [00:11, 16.28it/s]
MAE of pretrained model: 0.4892703932467949
__________________Epoch: 0__________________
0it [00:00, ?it/s]2025-01-02 05:34:56.114821: W external/tsl/tsl/framework/bfc_allocator.cc:482] Allocator (GPU_0_bfc) ran out of memory trying to allocate 6.25MiB (rounded to 6553600)requested by op
2025-01-02 05:34:56.115186: W external/tsl/tsl/framework/bfc_allocator.cc:494] ****************************************************************************************************
E0102 05:34:56.117071 66873 pjrt_stream_executor_client.cc:2809] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 6553600 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 853.58MiB
constant allocation: 4B
maybe_live_out allocation: 853.57MiB
preallocated temp allocation: 244.80MiB
preallocated temp fragmentation: 19.3KiB (0.01%)
total allocation: 1.91GiB
total fragmentation: 244.80MiB (12.54%)
Note that here the 1.0-200m model's MAE of pretrained model: 0.48 is actually better.
Then finetuning died with the out-of-memory error above; looks like it's time to buy GPU hours.
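Before paying for a bigger GPU, a couple of standard JAX memory knobs are worth trying; none of this is from the original run, just common mitigations:

# sketch of OOM mitigations -- set these before JAX initializes its backend
import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # already set in vehicle_fcst.py
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"   # cap JAX's share of VRAM

# in load_timesfm_model(), a smaller per-core batch also lowers peak memory:
# hparams = TimesFmHparams(backend="gpu", per_core_batch_size=8, horizon_len=128)

A smaller per_core_batch_size is the cheapest lever here, since the allocator failed on only a few extra MiB.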