import numpy as np
import scipy
from scipy.integrate import odeint
from scipy.optimize import minimize
from sklearn.metrics import mean_absolute_error, mean_squared_error
class SIR(object):
    """SIR (Susceptible-Infectious-Recovered) compartment model.

    The system is integrated with ``scipy.integrate.odeint`` on an evenly
    spaced time grid whose spacing is set by ``dt``. Any parameter can be
    overridden per call through the ``params`` dict of :meth:`solve_ode`,
    either as a constant or as a time-varying sequence (see
    :meth:`_get_param`).
    """

    def __init__(self, r_0, t_i, dt, init_S, init_I, init_R):
        """
        Args:
            r_0: Reproduction number; the transmission rate is ``r_0 / t_i``.
            t_i: Infectious period; the recovery rate is ``1 / t_i``.
            dt: Spacing of the output time grid used by :meth:`solve_ode`.
            init_S: Initial number of susceptible individuals.
            init_I: Initial number of infectious individuals.
            init_R: Initial number of recovered individuals.
        """
        self.r_0 = r_0
        self.t_i = t_i
        self.dt = dt
        self.init_state_list = [init_S, init_I, init_R]

    def _get_param(self, params, key, default, t=None):
        """Return the value of parameter ``key`` at time ``t``.

        Supports both constant and time-varying parameter overrides.

        Args:
            params: Optional dict of overrides. Anything that is not a dict
                is ignored.
            key: Name of the parameter to look up.
            default: Value used when ``params`` does not provide ``key``, or
                provides it with an unsupported type (e.g. ``None``).
            t: Current integration time. It selects the element of a
                sequence-valued parameter; ``None`` selects the first element.

        Returns:
            - A scalar (int / float / numpy number) override, unchanged.
            - For a list, tuple or ndarray (flattened), the element at index
              ``int(t / dt)``, clamped to the valid range so the last value is
              held once the sequence runs out. A single-element array, such as
              the ``x`` vector that ``scipy.optimize.minimize`` passes,
              therefore yields its only element.
            - ``default`` otherwise, including for an empty sequence.
        """
        if not isinstance(params, dict) or key not in params:
            return default
        param = params[key]
        if isinstance(param, np.ndarray):
            # Flatten so 0-d and (1,)-shaped optimizer vectors are indexable.
            param = param.ravel()
        if isinstance(param, (list, tuple, np.ndarray)):
            if len(param) == 0:
                return default
            step = 0 if t is None else int(t / self.dt)
            # odeint may probe times slightly outside the output grid; clamp.
            return param[min(max(step, 0), len(param) - 1)]
        if isinstance(param, (int, float, np.number)):
            # Constant override (the original silently fell back to default).
            return param
        return default

    def _ode(self, state_list, t=None, params=None):
        """Right-hand side of the SIR system, in the form odeint expects.

        Args:
            state_list: Current state ``[S, I, R]``.
            t: Current time (used for time-varying parameters).
            params: Optional parameter overrides (see :meth:`_get_param`).

        Returns:
            ``[dS/dt, dI/dt, dR/dt]``.
        """
        r_0 = self._get_param(params, 'r_0', self.r_0, t=t)
        t_i = self._get_param(params, 't_i', self.t_i, t=t)
        S, I, R = state_list
        N = S + I + R
        # New infections per unit time.
        infection = (r_0 / t_i) * (I / N) * S
        # Recoveries per unit time.
        recovery = I / t_i
        return [-infection, infection - recovery, recovery]

    def solve_ode(self, len_days=365, params=None):
        """Integrate the model from time 0.

        Args:
            len_days: Length of the simulated period.
            params: Optional dict of parameter overrides (see :meth:`_get_param`).

        Returns:
            ndarray of shape ``(int(len_days / dt), n_states)``. Row k holds the
            state at the k-th of ``int(len_days / dt)`` evenly spaced points on
            ``[0, len_days)``; the spacing equals ``dt`` when ``len_days`` is a
            multiple of ``dt``.
        """
        t = np.linspace(0, len_days, int(len_days / self.dt), endpoint=False)
        args = (params,) if params else ()
        return odeint(self._ode, self.init_state_list, t, args=args)
class CustomizedSEIRD(SIR):
    """SEIRD model (Susceptible-Exposed-Infectious-Recovered-Dead) with travel flows.

    Extends :class:`SIR` with:
    - an exposed compartment;
    - a split of removals into R and D;
    - constant inflow/outflow terms;
    - point estimation of ``r_0`` from daily observations.
    """

    def __init__(self, r_0=None, t_e=6.0, t_i=2.4, n_i_j=70000, n_j_i=40000, f=0.0001, dt=1,
                 init_S=126800000, init_E=0, init_I=1, init_R=0, init_D=0):
        """
        Args:
            r_0: Reproduction number; the transmission rate is ``r_0 / t_i``.
                It may be left as None and filled in by
                :meth:`exec_point_estimation`. Until then, :meth:`solve_ode`
                needs ``r_0`` supplied through ``params``.
            t_e: Latent period; the E -> I rate is ``1 / t_e``.
            t_i: Infectious period; the I -> R/D rate is ``1 / t_i``.
            n_i_j: Inflow per unit time, split between S and E.
            n_j_i: Outflow per unit time, removed from S.
            f: Fraction of removals from I that go to D instead of R.
            dt: Output time-grid spacing. It must be <= 1 for the daily
                estimators.
            init_S, init_E, init_I, init_R, init_D: Initial compartment sizes.
        """
        self.r_0 = r_0
        self.t_e = t_e
        self.t_i = t_i
        self.n_i_j = n_i_j
        self.n_j_i = n_j_i
        self.f = f
        self.dt = dt
        self.init_state_list = [init_S, init_E, init_I, init_R, init_D]

    def _ode(self, state_list, t=None, params=None):
        """Right-hand side of the SEIRD system, in the form odeint expects.

        Args:
            state_list: Current state ``[S, E, I, R, D]``.
            t: Current time (used for time-varying parameters).
            params: Optional parameter overrides (see ``_get_param``).

        Returns:
            ``[dS/dt, dE/dt, dI/dt, dR/dt, dD/dt]``.
        """
        r_0 = self._get_param(params, 'r_0', self.r_0, t=t)
        t_e = self._get_param(params, 't_e', self.t_e, t=t)
        t_i = self._get_param(params, 't_i', self.t_i, t=t)
        n_i_j = self._get_param(params, 'n_i_j', self.n_i_j, t=t)
        n_j_i = self._get_param(params, 'n_j_i', self.n_j_i, t=t)
        # Pass t like every other parameter. Without it, a list-valued f
        # failed on int(None / dt).
        f = self._get_param(params, 'f', self.f, t=t)
        S, E, I, R, D = state_list
        # D is excluded from the mixing population.
        N = S + E + I + R
        infection = (r_0 / t_i) * (I / N) * S
        # Share of the inflow credited to E rather than S.
        carrier_share = (E + I) / N
        onset = E / t_e
        removal = I / t_i
        return [
            -infection - n_j_i + (1 - carrier_share) * n_i_j,
            infection - onset + carrier_share * n_i_j,
            onset - removal,
            (1 - f) * removal,
            f * removal,
        ]

    def _daily_infected(self, r_0, n_days):
        """Simulate ``n_days`` with the given ``r_0`` and return I once per day.

        Groups the output grid into blocks of ``int(1 / dt)`` steps and takes
        the last point of each block.

        Raises:
            ValueError: If ``dt > 1``, so that a day spans less than one step.
        """
        steps_per_day = int(1 / self.dt)
        if steps_per_day < 1:
            raise ValueError(f'dt must be <= 1 to sample daily values, got {self.dt}')
        solution = self.solve_ode(len_days=n_days, params=dict(r_0=r_0))
        return solution[steps_per_day - 1::steps_per_day, 2]  # column 2 = I

    def _calc_neg_log_likelihood_r0(self, r_0, X):
        """Negative Poisson log-likelihood of ``X``, keeping only the terms that depend on r_0.

        The simulated daily I values serve as the Poisson means. The
        ``log(X!)`` term does not depend on ``r_0`` and is dropped.
        """
        lambda_arr = self._daily_infected(r_0, len(X))
        return - np.sum(- lambda_arr + X * np.log(lambda_arr))

    def _calc_error(self, r_0, X, metric=mean_absolute_error):
        """Error between ``X`` and the simulated daily I values.

        ``metric`` is called as ``metric(X, simulated)``; by default it is the
        mean absolute error.
        """
        e_arr = self._daily_infected(r_0, len(X))
        return metric(X, e_arr)

    def exec_point_estimation(self, init_r_0, X, project='mle'):
        """Point-estimate ``r_0`` from daily observations with Nelder-Mead.

        Args:
            init_r_0: Starting value for the optimizer.
            X: Observed daily values, compared with the simulated I compartment.
            project: ``'mle'`` (Poisson likelihood), ``'lad'`` (mean absolute
                error) or ``'ls'`` (mean squared error).

        Returns:
            The ``scipy.optimize.minimize`` result. For an unknown ``project``,
            prints a message and returns None. If ``self.r_0`` was None, it is
            set to the estimate.
        """
        objectives = {
            'mle': (self._calc_neg_log_likelihood_r0, (X,)),
            'lad': (self._calc_error, (X, mean_absolute_error)),
            'ls': (self._calc_error, (X, mean_squared_error)),
        }
        if project not in objectives:
            print(f'Invalid project: {project}')
            return None
        objective, args = objectives[project]
        result = minimize(objective, init_r_0, args=args, method='Nelder-Mead')
        if self.r_0 is None:
            self.r_0 = result.x[0]
        return result

    def exec_map(self):
        """MAP estimation.

        TODO: not implemented yet; currently does nothing and returns None.
        """

    def exec_mcmc(self):
        """MCMC estimation.

        TODO: not implemented yet; currently does nothing and returns None.
        """