Code:
#include <bits/stdc++.h>
#define ll long long
#define maxn 100002
using namespace std;
// Redirects standard input to the file named "<s>.in"
// (for example, setIO("input") makes stdin read from "input.in").
// The result of freopen is not inspected.
void setIO(string s) {
    s += ".in";  // s is a by-value copy, so extending it in place is local
    freopen(s.c_str(), "r", stdin);
}
// Disjoint-set forest over the ids 0 .. maxn-1.
// An id whose link points to itself is the representative of its set.
struct Union {
    int p[maxn];  // p[x]: link of x; p[x] == x marks a representative

    // Makes every id its own representative.
    void init() {
        for (int i = 0; i < maxn; ++i) p[i] = i;
    }

    // Returns the representative of x and re-points every id on the walked
    // chain directly at it (path compression).
    //
    // Written as two loops instead of recursion: a chain can be as long as
    // the number of ids, and one stack frame per link can exhaust the call
    // stack on such inputs. The resulting links are the same as with the
    // recursive formulation.
    int find(int x) {
        int root = x;
        while (p[root] != root) root = p[root];

        // Second pass: hang every id of the chain directly under root.
        while (p[x] != root) {
            const int next = p[x];
            p[x] = root;
            x = next;
        }
        return root;
    }
} tr;
int n;      // number of nodes; nodes are numbered 1..n (read in main)
int edges;  // number of adjacency entries created so far by addedge
int m;      // jumps of at most m levels are answered from nx (main sets 233)

ll val[maxn];  // val[u]: current value stored at node u

// Adjacency lists stored as linked lists of entries; entry index 0 means
// "no entry".
int hd[maxn];       // hd[u]: most recently added entry of node u
int to[maxn << 1];  // to[e]: node that entry e leads to
int nex[maxn << 1]; // nex[e]: next entry in the same node's list

int fa[21][maxn];   // fa[i][u]: ancestor of u that is 2^i levels up (0 if none)
int nx[400][maxn];  // nx[i][u]: ancestor of u exactly i levels up, for 0 <= i <= m
int dep[maxn];      // dep[u]: depth of u, the root has depth 1 and dep[0] is 0
int key[maxn];      // key[i] = 2^i (filled in main for i = 0..22)
void addedge(int u,int v) {
nex[++edges]=hd[u],hd[u]=edges,to[edges]=v;
}
// Traverses the tree from u (whose parent is ff; pass 0 for the root) and
// fills, for every node reached:
//   dep[v]    - depth, one more than the parent's depth,
//   fa[i][v]  - the ancestor 2^i levels up, for i = 0..20,
//   nx[i][v]  - the ancestor exactly i levels up, for i = 0..m.
//
// Uses an explicit stack rather than recursion: recursion depth would equal
// the height of the tree, and a path-shaped tree with ~1e5 nodes can exhaust
// the call stack. A node is only pushed after its parent has been processed,
// so the parent's table rows are always complete when they are read.
void dfs(int u,int ff) {
    std::vector<std::pair<int,int> > pending;  // (node, parent) still to visit
    pending.push_back(std::make_pair(u, ff));

    while (!pending.empty()) {
        const int cur = pending.back().first;
        const int par = pending.back().second;
        pending.pop_back();

        dep[cur] = dep[par] + 1;
        fa[0][cur] = par;
        nx[0][cur] = cur;
        nx[1][cur] = par;
        // i levels above cur is i-1 levels above its parent.
        for (int i = 2; i <= m; ++i) nx[i][cur] = nx[i-1][par];
        for (int i = 1; i < 21; ++i) fa[i][cur] = fa[i-1][fa[i-1][cur]];

        for (int e = hd[cur]; e; e = nex[e]) {
            const int v = to[e];
            if (v != par) pending.push_back(std::make_pair(v, cur));
        }
    }
}
// Returns the lowest common ancestor of x and y using the fa lifting table.
int LCA(int x,int y) {
    if (dep[x] != dep[y]) {
        // Make y the deeper node, then lift it to the depth of x.
        if (dep[x] > dep[y]) std::swap(x, y);
        for (int lvl = 20; lvl >= 0; --lvl) {
            const int cand = fa[lvl][y];
            if (dep[cand] >= dep[x]) y = cand;
        }
    }
    if (x == y) return x;

    // Lift both nodes while their ancestors differ; they end up as two
    // distinct children of the answer.
    for (int lvl = 20; lvl >= 0; --lvl) {
        const int ax = fa[lvl][x];
        const int ay = fa[lvl][y];
        if (ax != ay) {
            x = ax;
            y = ay;
        }
    }
    return fa[0][y];
}
// Returns the ancestor of x that is k levels up (0 once the walk leaves the
// tree). Distances of at most m are a single lookup in nx; larger ones are
// decomposed into powers of two and walked through fa.
int up(int x,int k) {
    if (k <= m) return nx[k][x];

    for (int bit = 20; bit >= 0; --bit) {
        if (key[bit] <= k) {
            x = fa[bit][x];
            k -= key[bit];
        }
        if (k == 0) break;  // nothing left to climb
    }
    return x;
}
void modify(int x) {
if(val[x]==1) return;
val[x]=sqrt(val[x]);
if(val[x]==1) tr.p[x]=tr.find(fa[0][x]);
}
// Returns the node that lies k steps away from y on the path from y to x.
// f is expected to be LCA(x, y) and k must not exceed the path length.
int jump(int x,int y,int f,int k) {
    const int y_side = dep[y] - dep[f];  // edges between y and f
    if (y_side >= k) {
        // The target is still on y's branch: climb straight up from y.
        return up(y, k);
    }
    // Otherwise the target is on x's branch; measure it from x instead.
    const int path_len = dep[x] + dep[y] - 2 * dep[f];
    return up(x, path_len - k);
}
// Returns the next node to visit above x when walking upward in steps of k.
//
// For k > m this is simply the ancestor k levels up. For smaller k the walk
// first moves to the tr representative of x's parent (modify and main link
// nodes whose value is 1 to their parent's representative, so those are
// passed over) and then climbs just far enough for the distance from x to
// be a multiple of k again.
int get(int x, int k) {
    if (k > m) return up(x, k);

    const int anchor = tr.find(fa[0][x]);
    const int gap = dep[x] - dep[anchor];   // levels between x and anchor
    const int extra = (k - gap % k) % k;    // levels still missing to a multiple of k
    return up(anchor, extra);
}
// Applies modify to every k-th node on the path from x to y, counted from x,
// and additionally to y itself when the path length is not a multiple of k.
// k must be non-zero (it is used as a divisor).
void update(int x,int y,int k) {
    int top = LCA(x, y);
    const int len = dep[x] + dep[y] - (dep[top] << 1);
    const int rem = len % k;

    if (rem) {
        // y is not on the k-grid: handle it separately, then move y back to
        // the last grid node so the remaining path length is a multiple of k.
        modify(y);
        y = jump(x, y, top, rem);
        top = LCA(x, y);
    }

    // x's branch, including the common ancestor when it is on the grid.
    while (dep[x] >= dep[top]) {
        modify(x);
        x = get(x, k);
    }
    // y's branch, strictly below the common ancestor.
    while (dep[y] > dep[top]) {
        modify(y);
        y = get(y, k);
    }
}
// Returns the sum of val over every k-th node on the path from x to y,
// counted from x, plus val[y] when the path length is not a multiple of k.
// k must be non-zero (it is used as a divisor).
//
// The walk uses get(), which may pass over nodes whose value is 1. To stay
// correct the result first counts one unit for every node on the k-grid and
// then adds val - 1 for each node that is actually visited.
ll query(int x,int y,int k) {
    int f = LCA(x, y);
    const int rem = (dep[x] + dep[y] - (dep[f] << 1)) % k;
    ll res = 0;

    if (rem) {
        // y is not on the k-grid: take it on its own, then move y back to
        // the last grid node so the remaining path length is a multiple of k.
        res += val[y];
        y = jump(x, y, f, rem);
        f = LCA(x, y);
    }

    // One unit per grid node on the (now aligned) path.
    res += (dep[x] + dep[y] - (dep[f] << 1)) / k + 1;

    // x's branch, including the common ancestor when it is on the grid.
    while (dep[x] >= dep[f]) {
        res += val[x] - 1;
        x = get(x, k);
    }
    // y's branch, strictly below the common ancestor.
    while (dep[y] > dep[f]) {
        res += val[y] - 1;
        y = get(y, k);
    }
    return res;
}
// Reads the tree and the operations from standard input and prints one line
// per query.
//
// Input layout as consumed below: n, then n node values, then n-1 edges
// "u v", then Q, then Q lines "op x y k". op == 0 runs update(x, y, k); any
// other op prints query(x, y, k).
//
// Every read is checked: on truncated or malformed input the program stops
// with exit status 1 instead of continuing with uninitialized variables.
int main() {
    // setIO("input");
    m = 233;  // jumps of at most this many levels are served from the nx table
    if (scanf("%d", &n) != 1) return 1;

    key[0] = 1;
    for (int i = 1; i <= 22; ++i) key[i] = key[i-1] * 2;

    for (int i = 1; i <= n; ++i) {
        if (scanf("%lld", &val[i]) != 1) return 1;
    }
    for (int i = 1; i < n; ++i) {
        int u, v;
        if (scanf("%d%d", &u, &v) != 2) return 1;
        addedge(u, v);
        addedge(v, u);
    }

    dfs(1, 0);  // node 1 is the root; 0 acts as its (non-existent) parent

    // Nodes whose value is already 1 are linked to their parent so that
    // tr.find passes over them; all others represent themselves.
    for (int i = 1; i <= n; ++i) {
        if (val[i] == 1) tr.p[i] = fa[0][i];
        else tr.p[i] = i;
    }

    int Q;
    if (scanf("%d", &Q) != 1) return 1;
    for (int i = 1; i <= Q; ++i) {
        int op, x, y, k;
        if (scanf("%d%d%d%d", &op, &x, &y, &k) != 4) return 1;
        if (op == 0) update(x, y, k);
        else printf("%lld\n", query(x, y, k));
    }
    return 0;
}