Reinstalling TimesFM: a record of the process

1. Create the conda environment: conda create --name a310 python=3.10.15

(base) root@VM-5-128-ubuntu:/workspace# conda activate a310

2. Install pyenv

curl https://pyenv.run | bash

3. Edit .bash_profile

The file lives under /root:

export PYENV_ROOT="$HOME/.pyenv"
[[ -d $PYENV_ROOT/bin ]] && export PATH="$PYENV_ROOT/bin:$PATH"
eval "$(pyenv init -)"
export PATH="/root/.local/bin:$PATH"


4. Reload the environment variables; the changes only take effect after running:

source ~/.bash_profile

5. Check

(a310) root@VM-5-128-ubuntu:/workspace# pyenv --version
pyenv 2.5.0

6. Install Poetry

(a310) root@VM-5-128-ubuntu:/workspace# pip install poetry


(a310) root@VM-5-128-ubuntu:/workspace# poetry --version
Poetry (version 1.8.5)

7. Install Python 3.10.15: pyenv install 3.10.15

Several errors came up along the way:

    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['/tmp/python-build.20241229151142.3168/Python-3.10.15/python', '-W', 'ignore::DeprecationWarning', '-c', '\nimport runpy\nimport sys\nsys.path = [\'/tmp/tmp8oem7l6p/setuptools-65.5.0-py3-none-any.whl\', \'/tmp/tmp8oem7l6p/pip-23.0.1-py3-none-any.whl\'] + sys.path\nsys.argv[1:] = [\'install\', \'--no-cache-dir\', \'--no-index\', \'--find-links\', \'/tmp/tmp8oem7l6p\', \'--root\', \'/\', \'--upgrade\', \'setuptools\', \'pip\']\nrunpy.run_module("pip", run_name="__main__", alter_sys=True)\n']' returned non-zero exit status 1.
make: *** [Makefile:1280: install] Error 1

Error analysis: the zlib module was missing.

First install sudo:

apt-get install sudo

Manually install pip and setuptools:

wget https://bootstrap.pypa.io/get-pip.py
python get-pip.py
pip install setuptools

Install the zlib library:
sudo apt-get update
sudo apt-get install zlib1g-dev

Then run pyenv install 3.10.15 again.

Check with pyenv versions; this command lists every installed Python version, and 3.10.15 should appear in the list.

It turned out the install still had not completed correctly.

So run these two commands:

sudo apt-get update

sudo apt-get install -y make build-essential libssl-dev zlib1g-dev libbz2-dev \
  libreadline-dev libsqlite3-dev wget curl llvm libncurses5-dev libncursesw5-dev \
  xz-utils tk-dev libffi-dev liblzma-dev python-openssl

(On newer Ubuntu releases the last package is named python3-openssl.)

With the dependencies in place, install Python 3.10.15 with pyenv again:

pyenv install 3.10.15

Check with pyenv versions:

(base) root@VM-5-128-ubuntu:/workspace/timesfm# pyenv versions
* system (set by /root/.pyenv/version)
  3.10.15
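
To double-check that the rebuilt interpreter really links against zlib, here is a small sanity check of my own (run it with the pyenv-built interpreter, e.g. ~/.pyenv/versions/3.10.15/bin/python check_zlib.py):

# check_zlib.py: raises ModuleNotFoundError if the build was missing zlib
import zlib

data = b"timesfm" * 100
packed = zlib.compress(data)
assert zlib.decompress(packed) == data
print(f"zlib OK: {len(data)} bytes compressed to {len(packed)} bytes")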

Then install Python 3.11 the same way (pyenv install 3.11).

Run pyenv versions again:

(base) root@VM-5-128-ubuntu:/workspace/timesfm# pyenv versions
* system (set by /root/.pyenv/version)
  3.10.15
  3.11.11

Point Poetry at the 3.10.15 interpreter:

(base) root@VM-5-128-ubuntu:/workspace/timesfm# poetry env use 3.10.15
Creating virtualenv timesfm-p1AFFT58-py3.10 in /root/.cache/pypoetry/virtualenvs
Using virtualenv: /root/.cache/pypoetry/virtualenvs/timesfm-p1AFFT58-py3.10

poetry lock

poetry install -E pax

(base) root@VM-5-128-ubuntu:/workspace/timesfm# poetry install -E  pax
Installing dependencies from lock file

After that you can run TimesFM under `poetry shell` or via `poetry run python3 ...`:

(base) root@VM-5-128-ubuntu:/workspace/timesfm# poetry run python3 -m vehicle_fcst --model_path=google/timesfm-1.0-200m

The code of vehicle_fcst.py is as follows:

import os
import gc

import numpy as np
import pandas as pd
from tqdm import tqdm  # progress bars

import jax
from jax import numpy as jnp

import timesfm
from timesfm import TimesFm, TimesFmHparams, TimesFmCheckpoint, patched_decoder, data_loader
from praxis import pax_fiddle, py_utils, pytypes, base_model, optimizers, schedules, base_hyperparams, base_layer
from paxml import tasks_lib, trainer_lib, checkpoints, learners, partitioning, checkpoint_types


# Environment variables: stop JAX from preallocating the whole GPU up front,
# and keep pmap state handling off TensorStore.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
os.environ['JAX_PMAP_USE_TENSORSTORE'] = 'false'


def load_timesfm_model():
    """Load the pretrained TimesFM model."""
    tfm = TimesFm(
        hparams=TimesFmHparams(
            backend="gpu",
            per_core_batch_size=32,
            horizon_len=128,
        ),
        checkpoint=TimesFmCheckpoint(
            huggingface_repo_id="google/timesfm-2.0-500m-jax"),  # or "google/timesfm-1.0-200m"
    )
    return tfm


def load_dataset(dataset_name):
    """加载数据集"""
    DATA_DICT = {
        "vehicle_sales": {
            "boundaries": [85991, 108028, 126414],
            "data_path": "datasets/output 2024.csv",
            "freq": "D",
        }
    }
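    # boundaries are row offsets into the CSV: rows [0, 85991) are the train
    # range, [85991, 108028) validation and [108028, 126414) test (see the
    # train/val/test_range arguments below).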
    data_path = DATA_DICT[dataset_name]["data_path"]
    freq = DATA_DICT[dataset_name]["freq"]
    int_freq = timesfm.freq_map(freq)
    boundaries = DATA_DICT[dataset_name]["boundaries"]

    data_df = pd.read_csv(data_path)
    ts_cols = ["sales_volume", "sales_price"]
    num_cov_cols = None
    cat_cov_cols = None
    context_len = 512
    pred_len = 96
    num_ts = len(ts_cols)
    batch_size = 4

    dtl = data_loader.TimeSeriesdata(
        data_path=data_path,
        datetime_col="date",
        num_cov_cols=num_cov_cols,
        cat_cov_cols=cat_cov_cols,
        ts_cols=np.array(ts_cols),
        train_range=[0, boundaries[0]],
        val_range=[boundaries[0], boundaries[1]],
        test_range=[boundaries[1], boundaries[2]],
        hist_len=context_len,
        pred_len=pred_len,
        batch_size=num_ts,
        freq=freq,
        normalize=True,
        epoch_len=None,
        holiday=False,
        permute=True,
    )
    train_batches = dtl.tf_dataset(mode="train", shift=1).batch(batch_size)
    val_batches = dtl.tf_dataset(mode="val", shift=pred_len)
    test_batches = dtl.tf_dataset(mode="test", shift=pred_len)
    return dtl, train_batches, val_batches, test_batches, batch_size, num_ts


def evaluate_pretrained_model(tfm, test_batches):
    """评估预训练模型"""
    mae_losses = []
    for batch in tqdm(test_batches.as_numpy_iterator()):
        past = batch[0]
        actuals = batch[3]
        forecasts, _ = tfm.forecast(list(past), [0] * past.shape[0], normalize=True)
        forecasts = forecasts[:, 0:actuals.shape[1]]
        mae_losses.append(np.abs(forecasts - actuals).mean())
    return np.mean(mae_losses)


def build_learner():
    """构建学习器"""
    return pax_fiddle.Config(
        learners.Learner,
        name='learner',
        loss_name='avg_qloss',
        optimizer=optimizers.Adam(
            epsilon=1e-7,
            clip_threshold=1e2,
            learning_rate=1e-2,
            lr_schedule=pax_fiddle.Config(
                schedules.Cosine,
                initial_value=1e-3,
                final_value=1e-4,
                total_steps=40000,
            ),
            ema_decay=0.9999,
        ),
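        # Exclude the stacked transformer core from backprop, i.e. freeze it
        # and finetune only the remaining input/output layers.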
        bprop_variable_exclusion=['.*/stacked_transformer_layer/.*'],
    )


def setup_training_task(model, learner):
    """设置训练任务"""
    task_p = tasks_lib.SingleTask(
        name='ts-learn',
        model=model,
        train=tasks_lib.SingleTask.Train(
            learner=learner,
        ),
    )
    task_p.model.ici_mesh_shape = [1, 1, 1]
    task_p.model.mesh_axis_names = ['replica', 'data', 'mdl']
    return task_p


def initialize_model_state(task_p, key, tbatch, batch_size, num_ts):
    """初始化模型状态"""
    key, init_key = jax.random.split(key)
    return trainer_lib.initialize_model_state(
        task_p,
        init_key,
        process_train_batch(tbatch, batch_size, num_ts),
        checkpoint_type=checkpoint_types.CheckpointType.GDA,
    )


def process_train_batch(batch, batch_size, num_ts):
    """处理训练批次数据"""
    past_ts = batch[0].reshape(batch_size * num_ts, -1)
    actual_ts = batch[3].reshape(batch_size * num_ts, -1)
    return py_utils.NestedMap(input_ts=past_ts, actual_ts=actual_ts)


def process_eval_batch(batch):
    """处理评估批次数据"""
    past_ts = batch[0]
    actual_ts = batch[3]
    return py_utils.NestedMap(input_ts=past_ts, actual_ts=actual_ts)


def reshape_batch_for_pmap(batch, num_devices):
    """为pmap重塑批次数据"""
    def _reshape(input_tensor):
        bsize = input_tensor.shape[0]
        residual_shape = list(input_tensor.shape[1:])
        nbsize = bsize // num_devices
        return jnp.reshape(input_tensor, [num_devices, nbsize] + residual_shape)
    return jax.tree.map(_reshape, batch)


def train_model(task_p, replicated_jax_states, train_prng_seed, eval_prng_seed,
                train_batches, val_batches, batch_size, num_ts, num_devices,
                CHECKPOINT_DIR, NUM_EPOCHS=100, PATIENCE=5, TRAIN_STEPS_PER_EVAL=1000):
    """Train the model, evaluating every TRAIN_STEPS_PER_EVAL steps with early stopping."""
    p_train_step = jax.pmap(lambda states, prng_key, inputs: trainer_lib.train_step_single_learner(task_p, states, prng_key, inputs),
                            axis_name='batch')
    p_eval_step = jax.pmap(lambda states, prng_key, inputs: trainer_lib.eval_step_single_learner(task_p, states, prng_key, inputs),
                            axis_name='batch')

    best_eval_loss = 1e7
    step_count = 0
    patience = 0
    train_losses = []

    for epoch in range(NUM_EPOCHS):
        print(f"__________________Epoch: {epoch}__________________", flush=True)
        train_its = train_batches.as_numpy_iterator()
        if patience >= PATIENCE:
            print("Early stopping.", flush=True)
            break
        for batch in tqdm(train_its):
            if patience >= PATIENCE:
                print("Early stopping.", flush=True)
                break
            tbatch = process_train_batch(batch, batch_size, num_ts)
            tbatch = reshape_batch_for_pmap(tbatch, num_devices)
            replicated_jax_states, step_fun_out = p_train_step(
                replicated_jax_states, train_prng_seed, tbatch
            )
            train_losses.append(step_fun_out.loss[0])
            if step_count % TRAIN_STEPS_PER_EVAL == 0:
                print(
                    f"Train loss at step {step_count}: {np.mean(train_losses)}",
                    flush=True
                )
                train_losses = []
                print("Starting eval.", flush=True)
                val_its = val_batches.as_numpy_iterator()
                eval_losses = []
                for ev_batch in tqdm(val_its):
                    ebatch = process_eval_batch(ev_batch)
                    ebatch = reshape_batch_for_pmap(ebatch, num_devices)
                    _, step_fun_out = p_eval_step(
                        replicated_jax_states, eval_prng_seed, ebatch
                    )
                    eval_losses.append(step_fun_out.loss[0])
                mean_loss = np.mean(eval_losses)
                print(f"Eval loss at step {step_count}: {mean_loss}", flush=True)
                if mean_loss < best_eval_loss or np.isnan(mean_loss):
                    best_eval_loss = mean_loss
                    print("Saving checkpoint.")
                    jax_state_for_saving = py_utils.maybe_unreplicate_for_fully_replicated(
                        replicated_jax_states
                    )
                    checkpoints.save_checkpoint(
                        jax_state_for_saving, CHECKPOINT_DIR, overwrite=True
                    )
                    patience = 0
                    del jax_state_for_saving
                    gc.collect()
                else:
                    patience += 1
                    print(f"patience: {patience}")
            step_count += 1


def evaluate_finetuned_model(tfm, test_batches):
    """评估微调后的模型"""
    mae_losses = []
    for batch in tqdm(test_batches.as_numpy_iterator()):
        past = batch[0]
        actuals = batch[3]
        _, forecasts = tfm.forecast(list(past), [0] * past.shape[0])
        if forecasts.ndim > 2:  # need a quantile axis to index into
            forecasts = forecasts[:, 0:actuals.shape[1], 5]  # index 5 corresponds to the 0.5 quantile
        else:
            print("Forecasts tensor does not have enough dimensions for the specified index.")
        mae_losses.append(np.abs(forecasts - actuals).mean())
    return np.mean(mae_losses)


if __name__ == "__main__":
    tfm = load_timesfm_model()
    dataset_name = "vehicle_sales"
    dtl, train_batches, val_batches, test_batches, batch_size, num_ts = load_dataset(dataset_name)

    # Evaluate the pretrained model
    mae_pretrained = evaluate_pretrained_model(tfm, test_batches)
    print(f"MAE of pretrained model: {mae_pretrained}")

    model = pax_fiddle.Config(
        patched_decoder.PatchedDecoderFinetuneModel,
        name='patched_decoder_finetune',
        core_layer_tpl=tfm.model_p,
    )
    learner = build_learner()
    task_p = setup_training_task(model, learner)

    DEVICES = np.array(jax.devices()).reshape([1, 1, 1])
    MESH = jax.sharding.Mesh(DEVICES, ['replica', 'data', 'mdl'])
    num_devices = jax.local_device_count()

    tbatch = next(train_batches.as_numpy_iterator())
    jax_model_states, _ = initialize_model_state(task_p, jax.random.PRNGKey(seed=1234), tbatch, batch_size, num_ts)
    jax_model_states.mdl_vars['params']['core_layer'] = tfm._train_state.mdl_vars['params']
    jax_vars = jax_model_states.mdl_vars
    gc.collect()

    replicated_jax_states = trainer_lib.replicate_model_state(jax_model_states)
    train_prng_seed = jax.random.split(jax.random.PRNGKey(seed=1234), num=jax.local_device_count())
    eval_prng_seed = jax.random.split(jax.random.PRNGKey(seed=1234), num=jax.local_device_count())

    CHECKPOINT_DIR = '/home/senrajat_google_com/ettm1_finetune'  # note: point this at a writable directory on your machine
    train_model(task_p, replicated_jax_states, train_prng_seed, eval_prng_seed,
                train_batches, val_batches, batch_size, num_ts, num_devices,
                CHECKPOINT_DIR)

    train_state = checkpoints.restore_checkpoint(jax_model_states, CHECKPOINT_DIR)
    print(train_state.step)
    tfm._train_state.mdl_vars['params'] = train_state.mdl_vars['params']['core_layer']
    tfm.jit_decode()

    mae_finetuned = evaluate_finetuned_model(tfm, test_batches)
    print(f"MAE of finetuned model: {mae_finetuned}")
    reduction_percentage = ((mae_pretrained - mae_finetuned) / mae_pretrained) * 100
    print(f"There is around a {reduction_percentage:.2f}% reduction in MAE from finetuning.")

The input file output 2024.csv is a daily series with a date column plus the sales_volume and sales_price columns (contents not reproduced here).

The training run looked like this:

(timesfm-py3.10) (base) root@VM-0-170-ubuntu:/workspace/timesfm# poetry run python3 -m test --model_path=google/timesfm-1.0-20
TimesFM v1.2.0. See https://github.com/google-research/timesfm/blob/master/README.md for updated APIs.
Loaded Jax TimesFM.
Fetching 6 files: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:00<00:00, 67468.70it/s]
Multiprocessing context has already been set.
Constructing model weights.
Constructed model weights in 3.09 seconds.
Restoring checkpoint from /root/.cache/huggingface/hub/models--google--timesfm-2.0-500m-jax/snapshots/47dedfcadf2abace1cc96071ddb798cfcd3bfcef/checkpoints.
WARNING:absl:No registered CheckpointArgs found for handler type: <class 'paxml.checkpoints.FlaxCheckpointHandler'>
WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by May 1st, 2024.
WARNING:absl:train_state_unpadded_shape_dtype_struct is not provided. We assume `train_state` is unpadded.

Restored checkpoint in 3.37 seconds.
Jitting decoding.
Jitted decoding in 21.43 seconds.
191it [00:09, 19.43it/s]
MAE of pretrained model: 0.8016693453744966

The code runs, with MAE of pretrained model: 0.80.

That run used the google/timesfm-2.0-500m-jax checkpoint for prediction. Switching load_timesfm_model() to the smaller model:

       checkpoint=TimesFmCheckpoint(huggingface_repo_id="google/timesfm-1.0-200m"),
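
A side note: the --model_path flag passed on the command line is never parsed; vehicle_fcst.py hardcodes the repo id inside load_timesfm_model(). A hypothetical sketch of wiring the flag through (not in the original script):

# Hypothetical: honor --model_path instead of hardcoding the repo id.
import argparse

from timesfm import TimesFm, TimesFmHparams, TimesFmCheckpoint

parser = argparse.ArgumentParser()
parser.add_argument("--model_path", default="google/timesfm-2.0-500m-jax")
args = parser.parse_args()

tfm = TimesFm(
    hparams=TimesFmHparams(backend="gpu", per_core_batch_size=32, horizon_len=128),
    checkpoint=TimesFmCheckpoint(huggingface_repo_id=args.model_path),
)

The log from the 1.0-200m run: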
 

Restored checkpoint in 1.43 seconds.
Jitting decoding.
Jitted decoding in 21.61 seconds.
191it [00:11, 16.28it/s]
MAE of pretrained model: 0.4892703932467949
__________________Epoch: 0__________________
0it [00:00, ?it/s]2025-01-02 05:34:56.114821: W external/tsl/tsl/framework/bfc_allocator.cc:482] Allocator (GPU_0_bfc) ran out of memory trying to allocate 6.25MiB (rounded to 6553600)requested by op 
2025-01-02 05:34:56.115186: W external/tsl/tsl/framework/bfc_allocator.cc:494] ****************************************************************************************************
E0102 05:34:56.117071   66873 pjrt_stream_executor_client.cc:2809] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 6553600 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:  853.58MiB
              constant allocation:         4B
        maybe_live_out allocation:  853.57MiB
     preallocated temp allocation:  244.80MiB
  preallocated temp fragmentation:    19.3KiB (0.01%)
                 total allocation:    1.91GiB
              total fragmentation:  244.80MiB (12.54%)

Notice that the MAE of the pretrained model here is 0.48, even better than the 500m checkpoint's 0.80.

Then training failed with out of memory. The 853.58 MiB parameter allocation is consistent with roughly 200M float32 weights, and the train step needs gradients and Adam state on top, so the ~1.9 GiB it requested did not fit in the free VRAM. Looks like it is time to buy GPU hours.
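
Before renting a bigger card, two knobs might be worth trying first (my own untested sketch; whether they squeeze the train step into the remaining VRAM is unverified):

# Hypothetical OOM mitigations (untested on this VM):
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'  # already set in vehicle_fcst.py
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.95'  # cap JAX at 95% of VRAM

# Smaller batches also shrink activation memory in the train step, e.g.
# batch_size = 2 in load_dataset() and per_core_batch_size=16 in
# load_timesfm_model().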

