Plasticity of Mohr-Coulomb with apex-smoothing
This tutorial aims to demonstrate how modern automatic or algorithmic differentiation (AD) techniques may be used to define a complex constitutive model that would otherwise demand a lot of by-hand differentiation. In particular, we implement the non-associative plasticity model of Mohr-Coulomb with apex-smoothing and apply it to a slope stability problem for soil. We use the JAX package to define the constitutive relations, including the differentiation of certain terms, and the FEMExternalOperator framework to incorporate this model into a weak formulation within UFL.
The tutorial is based on the limit analysis within semi-definite programming framework, where the plasticity model was replaced by the MFront/TFEL implementation of the Mohr-Coulomb elastoplastic model with apex smoothing.
Problem formulation
We solve a slope stability problem for a soil domain \(\Omega\) represented by the rectangle \([0; L] \times [0; H]\) under plane-strain conditions. Homogeneous Dirichlet boundary conditions \(\boldsymbol{u} = \boldsymbol{0}\) are imposed on the right side \(x = L\) and on the bottom side \(y = 0\). The loading consists of a gravitational body force \(\boldsymbol{q}=[0, -\gamma]^T\), where \(\gamma\) is the soil self-weight. The goal is to find the collapse load \(q_\text{lim}\), for which an analytical solution is known in the plane-strain case for the standard Mohr-Coulomb criterion. We follow the same Mandel-Voigt notation as in the von Mises plasticity tutorial. Stresses and strains are stored as four-component vectors \([\cdot_{xx}, \cdot_{yy}, \cdot_{zz}, \sqrt{2}\,\cdot_{xy}]\), with the out-of-plane strain \(\varepsilon_{zz}\) equal to zero.
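A small NumPy sketch may make the plane-strain Mandel-Voigt convention used in the code more concrete. The code stores vectors as four components [xx, yy, zz, √2 xy]. The sketch is not part of the tutorial, and the helper names to_mandel, from_mandel, strain_from_grad and elastic_stiffness are hypothetical. It checks two properties the implementation relies on:

- the Mandel stiffness matrix, with shear entry \(2\mu\), reproduces Hooke's law;
- the ordinary dot product of Mandel vectors equals the tensor double contraction \(\boldsymbol{\sigma} : \boldsymbol{\varepsilon}\). This is why ufl.inner can be applied directly to the four-component vectors.

```python
import numpy as np

SQRT2 = np.sqrt(2.0)


def to_mandel(t):
    """Symmetric 3x3 plane-strain tensor -> Mandel vector [xx, yy, zz, sqrt(2)*xy]."""
    return np.array([t[0, 0], t[1, 1], t[2, 2], SQRT2 * t[0, 1]])


def from_mandel(m):
    """Mandel vector -> symmetric 3x3 tensor (xz and yz components are zero)."""
    return np.array(
        [
            [m[0], m[3] / SQRT2, 0.0],
            [m[3] / SQRT2, m[1], 0.0],
            [0.0, 0.0, m[2]],
        ]
    )


def strain_from_grad(grad_u):
    """Same convention as the tutorial's epsilon(): the zz entry is 0 (plane strain)."""
    return np.array(
        [grad_u[0, 0], grad_u[1, 1], 0.0, SQRT2 * 0.5 * (grad_u[0, 1] + grad_u[1, 0])]
    )


def elastic_stiffness(E, nu):
    """Isotropic Mandel stiffness lambda * (1 x 1) + 2 * mu * I (shear entry 2*mu)."""
    lmbda = E * nu / ((1.0 + nu) * (1.0 - 2.0 * nu))
    mu = E / (2.0 * (1.0 + nu))
    one = np.array([1.0, 1.0, 1.0, 0.0])
    return lmbda * np.outer(one, one) + 2.0 * mu * np.eye(4), lmbda, mu


# Arbitrary displacement gradient, with the tutorial's E and nu.
grad_u = np.array([[1.0e-3, 4.0e-4], [-1.0e-4, -2.0e-3]])
eps = strain_from_grad(grad_u)
C, lmbda, mu = elastic_stiffness(6778.0, 0.25)
sigma_mandel = C @ eps

# The same stress computed with full tensors.
eps_t = from_mandel(eps)
sigma_tensor = lmbda * np.trace(eps_t) * np.eye(3) + 2.0 * mu * eps_t
```

The factor \(\sqrt{2}\) on the shear component is what makes the plain dot product equal the double contraction. It also means the shear entry of the stiffness matrix is \(2\mu\) rather than \(\mu\), as in C_elas below.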
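If \(V\) is a function space of admissible displacement fields, we can write the weak formulation of the problem as follows.
Find \(\boldsymbol{u} \in V\) such that
\[
F(\boldsymbol{u}; \boldsymbol{v}) = \int_\Omega \boldsymbol{\sigma}(\boldsymbol{\varepsilon}(\boldsymbol{u})) \cdot \boldsymbol{\varepsilon}(\boldsymbol{v}) \, \mathrm{d}\boldsymbol{x} - \int_\Omega \boldsymbol{q} \cdot \boldsymbol{v} \, \mathrm{d}\boldsymbol{x} = 0, \quad \forall \boldsymbol{v} \in V, \tag{1}
\]
where \(\boldsymbol{\sigma}\) is an external operator representing the stress tensor.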
Note
Although the tutorial shows the implementation of the Mohr-Coulomb model, it is general enough to be adapted to a wide range of plasticity models that may be defined through a yield surface and a plastic potential.
Implementation
Preamble
from mpi4py import MPI
from petsc4py import PETSc
import jax
import jax.lax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pyvista
from mpltools import annotation # for slope markers
from solvers import LinearProblem
from utilities import find_cell_by_point
import basix
import dolfinx.plot as plot
import ufl
from dolfinx import common, default_scalar_type, fem, mesh
from dolfinx_external_operator import (
FEMExternalOperator,
evaluate_external_operators,
evaluate_operands,
replace_external_operators,
)
jax.config.update("jax_enable_x64", True)
Here we define geometrical and material parameters of the problem as well as some useful constants.
E = 6778 # [MPa] Young modulus
nu = 0.25 # [-] Poisson ratio
c = 3.45 # [MPa] cohesion
phi = 30 * np.pi / 180 # [rad] friction angle
psi = 30 * np.pi / 180 # [rad] dilatancy angle
theta_T = 26 * np.pi / 180 # [rad] transition angle as defined by Abbo and Sloan
a = 0.26 * c / np.tan(phi) # [MPa] tension cut-off parameter
L, H = (1.2, 1.0)
Nx, Ny = (50, 50)
gamma = 1.0
domain = mesh.create_rectangle(MPI.COMM_WORLD, [np.array([0, 0]), np.array([L, H])], [Nx, Ny])
k_u = 2
gdim = domain.topology.dim
V = fem.functionspace(domain, ("Lagrange", k_u, (gdim,)))
# Boundary conditions
def on_right(x):
    return np.isclose(x[0], L)
def on_bottom(x):
    return np.isclose(x[1], 0.0)
bottom_dofs = fem.locate_dofs_geometrical(V, on_bottom)
right_dofs = fem.locate_dofs_geometrical(V, on_right)
bcs = [
fem.dirichletbc(np.array([0.0, 0.0], dtype=PETSc.ScalarType), bottom_dofs, V),
fem.dirichletbc(np.array([0.0, 0.0], dtype=PETSc.ScalarType), right_dofs, V),
]
def epsilon(v):
    grad_v = ufl.grad(v)
    return ufl.as_vector(
        [
            grad_v[0, 0],
            grad_v[1, 1],
            0.0,
            np.sqrt(2.0) * 0.5 * (grad_v[0, 1] + grad_v[1, 0]),
        ]
    )
k_stress = 2 * (k_u - 1)
dx = ufl.Measure(
"dx",
domain=domain,
metadata={"quadrature_degree": k_stress, "quadrature_scheme": "default"},
)
stress_dim = 2 * gdim
S_element = basix.ufl.quadrature_element(domain.topology.cell_name(), degree=k_stress, value_shape=(stress_dim,))
S = fem.functionspace(domain, S_element)
Du = fem.Function(V, name="Du")
u = fem.Function(V, name="Total_displacement")
du = fem.Function(V, name="du")
v = ufl.TrialFunction(V)
u_ = ufl.TestFunction(V)
sigma = FEMExternalOperator(epsilon(Du), function_space=S)
sigma_n = fem.Function(S, name="sigma_n")
Defining plasticity model and external operator
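The constitutive model of the soil is described by a non-associative plasticity law without hardening. It is defined by the Mohr-Coulomb yield surface \(f\) and the plastic potential \(g\), both of which may be expressed through the following function \(h\):
\[
\begin{gathered}
h(\boldsymbol{\sigma}, \alpha) = \frac{I_1(\boldsymbol{\sigma})}{3}\sin\alpha + \sqrt{J_2(\boldsymbol{\sigma}) K^2(\theta, \alpha) + a_g^2(\alpha)\sin^2\alpha} - c\cos\alpha, \\
f(\boldsymbol{\sigma}) = h(\boldsymbol{\sigma}, \phi), \qquad g(\boldsymbol{\sigma}) = h(\boldsymbol{\sigma}, \psi),
\end{gathered} \tag{2}
\]
where:

- \(\phi\) and \(\psi\) are the friction and dilatancy angles, and \(c\) is the cohesion;
- \(I_1(\boldsymbol{\sigma}) = \mathrm{tr} \boldsymbol{\sigma}\) is the first invariant of the stress tensor;
- \(J_2(\boldsymbol{\sigma}) = \frac{1}{2}\boldsymbol{s} \cdot \boldsymbol{s}\) is the second invariant of its deviatoric part \(\boldsymbol{s}\), and \(\theta\) is the Lode angle;
- \(a_g(\alpha) = a \tan\phi / \tan\alpha\) is the apex-smoothing term built from the tension cut-off parameter \(a\) (the function a_g in the code below).

The expression of the coefficient \(K(\theta, \alpha)\) may be found in the MFront/TFEL implementation of this plastic model.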
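The following NumPy sketch makes these invariants concrete; the tutorial itself computes them with JAX, and the function name invariants is hypothetical. It evaluates \(I_1\), \(J_2\), \(J_3\) and the Lode angle \(\theta\) from a plane-strain Mandel stress vector, using the same deviatoric projector and the same expressions as the functions J3 and theta defined later. It then compares the results with the full \(3 \times 3\) tensor definitions.

```python
import numpy as np

# Deviatoric projector and trace vector in plane-strain Mandel notation,
# identical to the tutorial's dev and tr arrays.
DEV = np.array(
    [
        [2.0 / 3.0, -1.0 / 3.0, -1.0 / 3.0, 0.0],
        [-1.0 / 3.0, 2.0 / 3.0, -1.0 / 3.0, 0.0],
        [-1.0 / 3.0, -1.0 / 3.0, 2.0 / 3.0, 0.0],
        [0.0, 0.0, 0.0, 1.0],
    ]
)
TR = np.array([1.0, 1.0, 1.0, 0.0])


def invariants(sigma):
    """Return I1, J2, J3 and the Lode angle for a Mandel stress vector."""
    s = DEV @ sigma
    I1 = TR @ sigma
    J2 = 0.5 * s @ s
    J3 = s[2] * (s[0] * s[1] - s[3] * s[3] / 2.0)  # det(s) when xz = yz = 0
    arg = np.clip(-3.0 * np.sqrt(3.0) * J3 / (2.0 * J2**1.5), -1.0, 1.0)
    return I1, J2, J3, np.arcsin(arg) / 3.0


# sigma_xx = -2, sigma_yy = -5, sigma_zz = -1.5, sigma_xy = 0.8 (Mandel: sqrt(2) * 0.8)
sigma = np.array([-2.0, -5.0, -1.5, 0.8 * np.sqrt(2.0)])
I1, J2, J3, lode = invariants(sigma)

# Reference deviatoric tensor built from the full 3x3 stress.
sigma_t = np.array([[-2.0, 0.8, 0.0], [0.8, -5.0, 0.0], [0.0, 0.0, -1.5]])
s_t = sigma_t - np.trace(sigma_t) / 3.0 * np.eye(3)
```

Like the tutorial's theta, this sketch is undefined for a purely hydrostatic state, where \(J_2 = 0\).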
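During plastic loading, the stress-strain state of the solid must satisfy the following system of nonlinear equations:
\[
\begin{cases}
\boldsymbol{r}_{g}(\boldsymbol{\sigma}_{n+1}, \Delta\lambda) = \boldsymbol{\sigma}_{n+1} - \boldsymbol{\sigma}_n - \boldsymbol{C}\left(\Delta\boldsymbol{\varepsilon} - \Delta\lambda \dfrac{\mathrm{d} g}{\mathrm{d}\boldsymbol{\sigma}}(\boldsymbol{\sigma}_{n+1})\right) = \boldsymbol{0}, \\
r_f(\boldsymbol{\sigma}_{n+1}) = f(\boldsymbol{\sigma}_{n+1}) = 0,
\end{cases} \tag{3}
\]
where \(\boldsymbol{\sigma}_n\) is the stress at the previous load step, \(\Delta\boldsymbol{\varepsilon}\) is the strain increment, \(\boldsymbol{C}\) is the elastic stiffness matrix and \(\Delta\lambda\) is the plastic multiplier.
We introduce the residual vector \(\boldsymbol{r} = [\boldsymbol{r}_{g}^T, r_f]^T\) and its argument vector \(\boldsymbol{x} = [\boldsymbol{\sigma}_{n+1}^T, \Delta\lambda]^T\). The system then becomes the nonlinear equation
\[
\boldsymbol{r}(\boldsymbol{x}) = \boldsymbol{0}. \tag{4}
\]
To solve this equation, we apply the Newton method and introduce the Jacobian of the residual vector \(\boldsymbol{j} = \frac{\mathrm{d} \boldsymbol{r}}{\mathrm{d} \boldsymbol{x}}\). Thus, for the plastic phase, we solve the following linear system for the Newton correction \(\boldsymbol{y}\) at each quadrature point:
\[
\boldsymbol{j}(\boldsymbol{x}_k)\,\boldsymbol{y} = -\boldsymbol{r}(\boldsymbol{x}_k), \qquad \boldsymbol{x}_{k+1} = \boldsymbol{x}_k + \boldsymbol{y}. \tag{5}
\]
During elastic loading, we consider a trivial system of equations. Elastic loading is the case where the elastic trial stress \(\boldsymbol{\sigma}_n + \boldsymbol{C}\Delta\boldsymbol{\varepsilon}\) satisfies \(f \leq 0\):
\[
\begin{cases}
\boldsymbol{r}_{g}(\boldsymbol{\sigma}_{n+1}, \Delta\lambda) = \boldsymbol{\sigma}_{n+1} - \boldsymbol{\sigma}_n - \boldsymbol{C}\Delta\boldsymbol{\varepsilon} = \boldsymbol{0}, \\
r_f(\Delta\lambda) = \Delta\lambda = 0.
\end{cases} \tag{6}
\]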
The algorithm solving the systems (5)–(6) is called the return-mapping procedure, and its solution defines the return-mapping correction of the stress tensor. Implementing the external operator \(\boldsymbol{\sigma}\) therefore means implementing this algorithm.
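The return-mapping logic can be tried out without JAX on a simpler material. The NumPy sketch below follows the same structure as the tutorial:

- if the elastic trial stress \(\boldsymbol{\sigma}_n + \boldsymbol{C}\Delta\boldsymbol{\varepsilon}\) is admissible, the trivial elastic solution is returned;
- otherwise, Newton iterations are applied to \(\boldsymbol{r}(\boldsymbol{x}) = \boldsymbol{0}\) with \(\boldsymbol{x} = [\boldsymbol{\sigma}_{n+1}^T, \Delta\lambda]^T\).

Several assumptions differ from the tutorial:

- an associative von Mises criterion with a hypothetical yield stress \(\sigma_0 = 10\) MPa replaces Mohr-Coulomb;
- a central finite-difference Jacobian replaces jax.jacfwd;
- Newton starts from the elastic trial state rather than from \(\boldsymbol{\sigma}_n\), because the von Mises normal is undefined at zero deviatoric stress.

```python
import numpy as np

# Tutorial's elastic constants; sigma_0 is a hypothetical von Mises yield stress.
E, nu, sigma_0 = 6778.0, 0.25, 10.0
lmbda = E * nu / ((1.0 + nu) * (1.0 - 2.0 * nu))
mu = E / (2.0 * (1.0 + nu))
one = np.array([1.0, 1.0, 1.0, 0.0])  # trace in Mandel notation
C = lmbda * np.outer(one, one) + 2.0 * mu * np.eye(4)  # elastic stiffness
DEV = np.eye(4) - np.outer(one, one) / 3.0  # deviatoric projector


def f(sigma):
    """Von Mises criterion, used here instead of Mohr-Coulomb (associative: g = f)."""
    s = DEV @ sigma
    return np.sqrt(1.5 * s @ s) - sigma_0


def dfdsigma(sigma):
    s = DEV @ sigma
    return 1.5 * s / np.sqrt(1.5 * s @ s)


def residual(x, deps, sigma_n):
    """r = [r_g, r_f] for the plastic system, with x = [sigma_{n+1}, dlambda]."""
    sigma, dlambda = x[:4], x[4]
    r_g = sigma - sigma_n - C @ (deps - dlambda * dfdsigma(sigma))
    return np.append(r_g, f(sigma))


def jacobian_fd(x, deps, sigma_n, h=1e-7):
    """Central finite-difference Jacobian dr/dx (stand-in for jax.jacfwd)."""
    J = np.empty((5, 5))
    for k in range(5):
        e = np.zeros(5)
        e[k] = h * max(1.0, abs(x[k]))
        J[:, k] = (residual(x + e, deps, sigma_n) - residual(x - e, deps, sigma_n)) / (2.0 * e[k])
    return J


def return_mapping(deps, sigma_n, tol=1e-10, max_iter=50):
    sigma_trial = sigma_n + C @ deps
    if f(sigma_trial) <= 0.0:  # elastic step: trial stress is admissible
        return sigma_trial, 0.0, 0
    x = np.append(sigma_trial, 0.0)  # plastic step: Newton on r(x) = 0
    r = residual(x, deps, sigma_n)
    norm_r0 = np.linalg.norm(r)
    niter = 0
    while np.linalg.norm(r) / norm_r0 > tol and niter < max_iter:
        x = x + np.linalg.solve(jacobian_fd(x, deps, sigma_n), -r)  # j y = -r
        r = residual(x, deps, sigma_n)
        niter += 1
    return x[:4], x[4], niter


deps = np.array([5.0e-3, -2.5e-3, 0.0, 1.0e-3])
sigma, dlambda, niter = return_mapping(deps, np.zeros(4))
```

For von Mises, the converged stress can be compared with the closed-form radial return \(\Delta\lambda = (q^\text{trial} - \sigma_0)/(3\mu)\), where \(q^\text{trial}\) is the von Mises equivalent trial stress. No such closed form exists for the smoothed Mohr-Coulomb surface, which is why the tutorial relies on Newton iterations and AD.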
The automatic differentiation tools of the JAX library are applied to calculate the derivatives \(\frac{\mathrm{d} g}{\mathrm{d}\boldsymbol{\sigma}}, \frac{\mathrm{d} \boldsymbol{r}}{\mathrm{d} \boldsymbol{x}}\) as well as the stress tensor derivative or the consistent tangent stiffness matrix \(\boldsymbol{C}_\text{tang} = \frac{\mathrm{d}\boldsymbol{\sigma}}{\mathrm{d}\boldsymbol{\varepsilon}}\).
Defining yield surface and plastic potential
First of all, we define supplementary functions that help us express the yield surface \(f\) and the plastic potential \(g\). In the following definitions, we use built-in functions of the JAX package, in particular the conditional primitive jax.lax.cond. It is required for the AD tool and just-in-time compilation to work correctly, because a Python if statement cannot branch on the value of an array traced by jax.jit or jax.vmap. For more details, please visit the JAX documentation.
def J3(s):
    return s[2] * (s[0] * s[1] - s[3] * s[3] / 2.0)
def J2(s):
    return 0.5 * jnp.vdot(s, s)
def theta(s):
    J2_ = J2(s)
    arg = -(3.0 * np.sqrt(3.0) * J3(s)) / (2.0 * jnp.sqrt(J2_ * J2_ * J2_))
    arg = jnp.clip(arg, -1.0, 1.0)
    theta = 1.0 / 3.0 * jnp.arcsin(arg)
    return theta
def sign(x):
    return jax.lax.cond(x < 0.0, lambda x: -1, lambda x: 1, x)
def coeff1(theta, angle):
    return np.cos(theta_T) - (1.0 / np.sqrt(3.0)) * np.sin(angle) * np.sin(theta_T)
def coeff2(theta, angle):
    return sign(theta) * np.sin(theta_T) + (1.0 / np.sqrt(3.0)) * np.sin(angle) * np.cos(theta_T)
coeff3 = 18.0 * np.cos(3.0 * theta_T) * np.cos(3.0 * theta_T) * np.cos(3.0 * theta_T)
def C(theta, angle):
    return (
        -np.cos(3.0 * theta_T) * coeff1(theta, angle) - 3.0 * sign(theta) * np.sin(3.0 * theta_T) * coeff2(theta, angle)
    ) / coeff3
def B(theta, angle):
    return (
        sign(theta) * np.sin(6.0 * theta_T) * coeff1(theta, angle) - 6.0 * np.cos(6.0 * theta_T) * coeff2(theta, angle)
    ) / coeff3
def A(theta, angle):
    return (
        -(1.0 / np.sqrt(3.0)) * np.sin(angle) * sign(theta) * np.sin(theta_T)
        - B(theta, angle) * sign(theta) * np.sin(3 * theta_T)
        - C(theta, angle) * np.sin(3.0 * theta_T) * np.sin(3.0 * theta_T)
        + np.cos(theta_T)
    )
def K(theta, angle):
    def K_false(theta):
        return jnp.cos(theta) - (1.0 / np.sqrt(3.0)) * np.sin(angle) * jnp.sin(theta)
    def K_true(theta):
        return (
            A(theta, angle)
            + B(theta, angle) * jnp.sin(3.0 * theta)
            + C(theta, angle) * jnp.sin(3.0 * theta) * jnp.sin(3.0 * theta)
        )
    return jax.lax.cond(jnp.abs(theta) > theta_T, K_true, K_false, theta)
def a_g(angle):
    return a * np.tan(phi) / np.tan(angle)
dev = np.array(
    [
        [2.0 / 3.0, -1.0 / 3.0, -1.0 / 3.0, 0.0],
        [-1.0 / 3.0, 2.0 / 3.0, -1.0 / 3.0, 0.0],
        [-1.0 / 3.0, -1.0 / 3.0, 2.0 / 3.0, 0.0],
        [0.0, 0.0, 0.0, 1.0],
    ],
    dtype=PETSc.ScalarType,
)
tr = np.array([1.0, 1.0, 1.0, 0.0], dtype=PETSc.ScalarType)
def surface(sigma_local, angle):
    s = dev @ sigma_local
    I1 = tr @ sigma_local
    theta_ = theta(s)
    return (
        (I1 / 3.0 * np.sin(angle))
        + jnp.sqrt(
            J2(s) * K(theta_, angle) * K(theta_, angle) + a_g(angle) * a_g(angle) * np.sin(angle) * np.sin(angle)
        )
        - c * np.cos(angle)
    )
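For \(|\theta| > \theta_T\), the coefficients A, B and C above replace the classical Mohr-Coulomb coefficient (the K_false branch) with a smooth function of \(\sin 3\theta\). The NumPy check below shows what this achieves. It re-implements the two branches with the same formulas (the names K_mc, K_smooth and slope are hypothetical) and compares value and slope on both sides of \(\theta = \pm\theta_T\). Both jumps come out at round-off or finite-difference level. So \(K\), and with it the gradient of the yield surface, is continuous across the transition angle.

```python
import numpy as np

theta_T = 26 * np.pi / 180  # transition angle, as in the tutorial
phi = 30 * np.pi / 180  # friction angle, as in the tutorial


def K_mc(theta, angle):
    """Classical Mohr-Coulomb coefficient (the K_false branch)."""
    return np.cos(theta) - np.sin(angle) * np.sin(theta) / np.sqrt(3.0)


def K_smooth(theta, angle):
    """Smoothed branch A + B sin(3 theta) + C sin^2(3 theta) (the K_true branch)."""
    sgn = -1.0 if theta < 0.0 else 1.0  # same convention as the tutorial's sign()
    c1 = np.cos(theta_T) - np.sin(angle) * np.sin(theta_T) / np.sqrt(3.0)
    c2 = sgn * np.sin(theta_T) + np.sin(angle) * np.cos(theta_T) / np.sqrt(3.0)
    c3 = 18.0 * np.cos(3.0 * theta_T) ** 3
    C = (-np.cos(3.0 * theta_T) * c1 - 3.0 * sgn * np.sin(3.0 * theta_T) * c2) / c3
    B = (sgn * np.sin(6.0 * theta_T) * c1 - 6.0 * np.cos(6.0 * theta_T) * c2) / c3
    A = (
        -np.sin(angle) * sgn * np.sin(theta_T) / np.sqrt(3.0)
        - B * sgn * np.sin(3.0 * theta_T)
        - C * np.sin(3.0 * theta_T) ** 2
        + np.cos(theta_T)
    )
    return A + B * np.sin(3.0 * theta) + C * np.sin(3.0 * theta) ** 2


def slope(K, theta, angle, h=1e-6):
    """Central finite-difference approximation of dK/dtheta."""
    return (K(theta + h, angle) - K(theta - h, angle)) / (2.0 * h)


for t in (theta_T, -theta_T):
    jump = K_smooth(t, phi) - K_mc(t, phi)
    slope_jump = slope(K_smooth, t, phi) - slope(K_mc, t, phi)
    print(f"theta = {t:+.4f} rad: value jump {jump:.1e}, slope jump {slope_jump:.1e}")
```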
By choosing the appropriate angle, we define the yield surface \(f\) and the plastic potential \(g\).
def f(sigma_local):
    return surface(sigma_local, phi)
def g(sigma_local):
    return surface(sigma_local, psi)
dgdsigma = jax.jacfwd(g)
Solving constitutive equations
In this section, we define the constitutive model by solving the systems (5)–(6). They must be solved at each Gauss point, so we apply the Newton method, implement the whole algorithm locally and then vectorize the final result using jax.vmap.
In the following cell, we locally define the residual \(\boldsymbol{r}\) and its Jacobian drdx.
lmbda = E * nu / ((1.0 + nu) * (1.0 - 2.0 * nu))
mu = E / (2.0 * (1.0 + nu))
C_elas = np.array(
[
[lmbda + 2 * mu, lmbda, lmbda, 0],
[lmbda, lmbda + 2 * mu, lmbda, 0],
[lmbda, lmbda, lmbda + 2 * mu, 0],
[0, 0, 0, 2 * mu],
],
dtype=PETSc.ScalarType,
)
S_elas = np.linalg.inv(C_elas)
ZERO_VECTOR = np.zeros(stress_dim, dtype=PETSc.ScalarType)
def deps_p(sigma_local, dlambda, deps_local, sigma_n_local):
    sigma_elas_local = sigma_n_local + C_elas @ deps_local
    yielding = f(sigma_elas_local)
    def deps_p_elastic(sigma_local, dlambda):
        return ZERO_VECTOR
    def deps_p_plastic(sigma_local, dlambda):
        return dlambda * dgdsigma(sigma_local)
    return jax.lax.cond(yielding <= 0.0, deps_p_elastic, deps_p_plastic, sigma_local, dlambda)
def r_g(sigma_local, dlambda, deps_local, sigma_n_local):
    deps_p_local = deps_p(sigma_local, dlambda, deps_local, sigma_n_local)
    return sigma_local - sigma_n_local - C_elas @ (deps_local - deps_p_local)
def r_f(sigma_local, dlambda, deps_local, sigma_n_local):
    sigma_elas_local = sigma_n_local + C_elas @ deps_local
    yielding = f(sigma_elas_local)
    def r_f_elastic(sigma_local, dlambda):
        return dlambda
    def r_f_plastic(sigma_local, dlambda):
        return f(sigma_local)
    return jax.lax.cond(yielding <= 0.0, r_f_elastic, r_f_plastic, sigma_local, dlambda)
def r(x_local, deps_local, sigma_n_local):
    sigma_local = x_local[:stress_dim]
    dlambda_local = x_local[-1]
    res_g = r_g(sigma_local, dlambda_local, deps_local, sigma_n_local)
    res_f = r_f(sigma_local, dlambda_local, deps_local, sigma_n_local)
    res = jnp.c_["0,1,-1", res_g, res_f] # concatenates an array and a scalar
    return res
drdx = jax.jacfwd(r)
Then we define the function return_mapping that implements the return-mapping algorithm numerically via the Newton method.
Nitermax, tol = 200, 1e-8
ZERO_SCALAR = np.array([0.0])
def return_mapping(deps_local, sigma_n_local):
    """Performs the return-mapping procedure.
    It solves elastoplastic constitutive equations numerically by applying the
    Newton method at a single Gauss point. The Newton loop is implemented via
    `jax.lax.while_loop`.
    The function returns `sigma_local` two times to reuse its values after
    differentiation, i.e. once we apply
    `jax.jacfwd(return_mapping, has_aux=True)` the output function will
    have an output of
    `(C_tang_local, (sigma_local, niter_total, yielding, norm_res, dlambda))`.
    Returns:
        sigma_local: The stress at the current Gauss point.
        niter_total: The total number of iterations.
        yielding: The value of the yield function.
        norm_res: The norm of the residuals.
        dlambda: The value of the plastic multiplier.
    """
    niter = 0
    dlambda = ZERO_SCALAR
    sigma_local = sigma_n_local
    x_local = jnp.concatenate([sigma_local, dlambda])
    res = r(x_local, deps_local, sigma_n_local)
    norm_res0 = jnp.linalg.norm(res)
    def cond_fun(state):
        norm_res, niter, _ = state
        return jnp.logical_and(norm_res / norm_res0 > tol, niter < Nitermax)
    def body_fun(state):
        norm_res, niter, history = state
        x_local, deps_local, sigma_n_local, res = history
        j = drdx(x_local, deps_local, sigma_n_local)
        j_inv_vp = jnp.linalg.solve(j, -res)
        x_local = x_local + j_inv_vp
        res = r(x_local, deps_local, sigma_n_local)
        norm_res = jnp.linalg.norm(res)
        history = x_local, deps_local, sigma_n_local, res
        niter += 1
        return (norm_res, niter, history)
    history = (x_local, deps_local, sigma_n_local, res)
    norm_res, niter_total, x_local = jax.lax.while_loop(cond_fun, body_fun, (norm_res0, niter, history))
    sigma_local = x_local[0][:stress_dim]
    dlambda = x_local[0][-1]
    sigma_elas_local = C_elas @ deps_local
    yielding = f(sigma_n_local + sigma_elas_local)
    return sigma_local, (sigma_local, niter_total, yielding, norm_res, dlambda)
Consistent tangent stiffness matrix
Automatic differentiation can compute the derivative not only of a mathematical expression but also of a numerical algorithm. For instance, AD can calculate the derivative of the stress tensor \(\boldsymbol{\sigma}\), the output of the return-mapping function, with respect to its input, the strain increment. This feature is very useful for the consistent tangent matrix \(\boldsymbol{C}_\text{tang}\), as there is no need to write an additional program computing the stress derivative.
JAX’s AD tool can take the derivative of the function return_mapping, which is essentially a while loop. The derivative is taken with respect to the first output, and the remaining outputs are used as auxiliary data. Thus, the derivative dsigma_ddeps returns both the consistent tangent matrix and the stress tensor, so there is no need for a supplementary computation of the stress tensor.
dsigma_ddeps = jax.jacfwd(return_mapping, has_aux=True)
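Differentiating an algorithm rather than a formula can be illustrated without JAX. The plain-Python sketch below implements a toy forward-mode AD with dual numbers; the class Dual and the function newton_sqrt are hypothetical and unrelated to JAX internals. Every arithmetic operation propagates a value together with its derivative. Running a Newton loop on dual numbers therefore yields both the root and its derivative with respect to the input in a single pass. This is similar to how jax.jacfwd(return_mapping, has_aux=True) yields the stress and the consistent tangent together.

```python
import math


class Dual:
    """Number val + der * eps with eps**2 = 0 (toy forward-mode AD)."""

    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    @staticmethod
    def lift(other):
        return other if isinstance(other, Dual) else Dual(other)

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.val + other.val, self.der + other.der)

    __radd__ = __add__

    def __mul__(self, other):
        other = Dual.lift(other)
        return Dual(self.val * other.val, self.der * other.val + self.val * other.der)

    __rmul__ = __mul__

    def __truediv__(self, other):
        other = Dual.lift(other)
        return Dual(
            self.val / other.val,
            (self.der * other.val - self.val * other.der) / other.val**2,
        )


def newton_sqrt(a, tol=1e-14, max_iter=100):
    """Solve x**2 = a with Newton's method x <- (x + a / x) / 2, where a is a Dual."""
    x = (a + 1.0) * 0.5
    for _ in range(max_iter):
        x_new = (x + a / x) * 0.5
        converged = abs(x_new.val - x.val) <= tol * abs(x_new.val)
        x = x_new
        if converged:
            break
    return x


root = newton_sqrt(Dual(2.0, 1.0))  # seed da/da = 1
print(root.val, root.der)  # sqrt(2) and d(sqrt(a))/da at a = 2
```

No derivative formula for the square root is written anywhere. The derivative emerges from differentiating each Newton iteration and agrees with the analytical value \(1/(2\sqrt{a})\).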
Defining external operator
Once we define the function dsigma_ddeps, which evaluates both the external operator and its derivative locally, we can simply vectorize it and define the final implementation of the external operator derivative.
Note
The function dsigma_ddeps, which contains a while_loop, is designed to be called at a single Gauss point, which is why we need to vectorize it over all points of our function space S. For this purpose, we use the vmap function of JAX. It creates another while_loop, which terminates only when all mapped loops terminate. Find further details in this discussion.
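The NumPy sketch below mimics these semantics for a batch of independent Newton solves, using square roots only because they are easy to check. It is an illustration of the described behaviour, not JAX's implementation, and batched_newton_sqrt is a hypothetical name. A single loop runs while any lane is still active, and lanes that have already converged keep their state unchanged. Each lane therefore reports its own iteration count, as in the inner Newton summaries printed below, while the loop itself runs as many times as the slowest lane requires.

```python
import numpy as np


def batched_newton_sqrt(a, tol=1e-12, max_iter=50):
    """Solve x**2 = a[i] for every lane i in one loop, freezing converged lanes."""
    a = np.asarray(a, dtype=float)
    x = 0.5 * (a + 1.0)  # per-lane initial guess
    niter = np.zeros(a.shape, dtype=int)  # per-lane iteration counters
    active = np.ones(a.shape, dtype=bool)
    while active.any():  # the batched loop stops only when every lane is done
        x_new = 0.5 * (x + a / x)
        x = np.where(active, x_new, x)  # converged lanes keep their old state
        niter = np.where(active, niter + 1, niter)
        residual = np.abs(x * x - a) / a
        active = (residual > tol) & (niter < max_iter)
    return x, niter


a = np.array([1.0, 2.0, 100.0, 1.0e6])
x, niter = batched_newton_sqrt(a)
print(x, niter)
```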
dsigma_ddeps_vec = jax.jit(jax.vmap(dsigma_ddeps, in_axes=(0, 0)))
def C_tang_impl(deps):
    deps_ = deps.reshape((-1, stress_dim))
    sigma_n_ = sigma_n.x.array.reshape((-1, stress_dim))
    (C_tang_global, state) = dsigma_ddeps_vec(deps_, sigma_n_)
    sigma_global, niter, yielding, norm_res, dlambda = state
    unique_iters, counts = jnp.unique(niter, return_counts=True)
    print("\tInner Newton summary:")
    print(f"\t\tUnique number of iterations: {unique_iters}")
    print(f"\t\tCounts of unique number of iterations: {counts}")
    print(f"\t\tMaximum f: {jnp.max(yielding)}")
    print(f"\t\tMaximum residual: {jnp.max(norm_res)}")
    return C_tang_global.reshape(-1), sigma_global.reshape(-1)
As in the von Mises example, we do not explicitly implement the evaluation of the external operator. Instead, we obtain its values while evaluating its derivative and then update the values of the operator in the main Newton loop.
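def sigma_external(derivatives):
    if derivatives == (1,):
        return C_tang_impl
    else:
        # Only the first derivative is implemented; the operator values come with it.
        raise NotImplementedError(f"No external function for derivatives {derivatives}")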
sigma.external_function = sigma_external
Defining the forms
q = fem.Constant(domain, default_scalar_type((0, -gamma)))
def F_ext(v):
    return ufl.dot(q, v) * dx
u_hat = ufl.TrialFunction(V)
F = ufl.inner(epsilon(u_), sigma) * dx - F_ext(u_)
J = ufl.derivative(F, Du, u_hat)
J_expanded = ufl.algorithms.expand_derivatives(J)
F_replaced, F_external_operators = replace_external_operators(F)
J_replaced, J_external_operators = replace_external_operators(J_expanded)
F_form = fem.form(F_replaced)
J_form = fem.form(J_replaced)
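For reference, the Jacobian form built above by ufl.derivative and expand_derivatives corresponds to the following linearization of the residual. It is written here for the displacement increment \(\Delta\boldsymbol{u}\) (Du in the code), with trial function \(\hat{\boldsymbol{u}}\) (u_hat) and test function \(\boldsymbol{v}\) (u_). This is a sketch of the underlying mathematics, not the form UFL actually generates. Since the external load term does not depend on \(\Delta\boldsymbol{u}\), only the stress term contributes. The derivative of the external operator with respect to its operand is the consistent tangent supplied by C_tang_impl:

\[
J(\Delta\boldsymbol{u}; \hat{\boldsymbol{u}}, \boldsymbol{v}) = \left.\frac{\mathrm{d}}{\mathrm{d}t} F(\Delta\boldsymbol{u} + t\,\hat{\boldsymbol{u}}; \boldsymbol{v})\right|_{t=0} = \int_\Omega \boldsymbol{\varepsilon}(\boldsymbol{v}) \cdot \boldsymbol{C}_\text{tang}\,\boldsymbol{\varepsilon}(\hat{\boldsymbol{u}}) \, \mathrm{d}\boldsymbol{x}, \qquad \boldsymbol{C}_\text{tang} = \frac{\mathrm{d}\boldsymbol{\sigma}}{\mathrm{d}\boldsymbol{\varepsilon}}.
\]

Each outer Newton iteration then solves \(J(\Delta\boldsymbol{u}; \delta\boldsymbol{u}, \boldsymbol{v}) = -F(\Delta\boldsymbol{u}; \boldsymbol{v})\) for all \(\boldsymbol{v}\) for the correction \(\delta\boldsymbol{u}\) and updates \(\Delta\boldsymbol{u} \leftarrow \Delta\boldsymbol{u} + \delta\boldsymbol{u}\).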
Variables initialization and compilation
Before solving the problem, we have to initialize the values of the consistent tangent matrix, as they are required for assembling the system. During the first load step we expect a purely elastic response. It is therefore enough to solve the constitutive equations once at each Gauss point for a displacement field that produces an elastic response, which initializes the consistent tangent matrix with the elastic moduli.
At the same time, we can measure the compilation overhead caused by the first call of the JIT-compiled JAX functions.
Du.x.array[:] = 1.0
sigma_n.x.array[:] = 0.0
timer = common.Timer("DOLFINx_timer")
timer.start()
evaluated_operands = evaluate_operands(F_external_operators)
_ = evaluate_external_operators(J_external_operators, evaluated_operands)
timer.stop()
pass_1 = timer.elapsed()[0]
timer.start()
evaluated_operands = evaluate_operands(F_external_operators)
_ = evaluate_external_operators(J_external_operators, evaluated_operands)
timer.stop()
pass_2 = timer.elapsed()[0]
print(f"\nJAX's JIT compilation overhead: {pass_1 - pass_2}")
Inner Newton summary:
Unique number of iterations: [1]
Counts of unique number of iterations: [15000]
Maximum f: -2.2109628558449494
Maximum residual: 0.0
Inner Newton summary:
Unique number of iterations: [1]
Counts of unique number of iterations: [15000]
Maximum f: -2.2109628558449494
Maximum residual: 0.0
JAX's JIT compilation overhead: 17.53
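The printed maximum of f can be checked by hand. With Du set to a uniform value, the strain and hence the stress vanish up to round-off, so \(I_1 = J_2 = 0\). Moreover, \(a_g(\phi) = a\), so the yield function reduces to \(f(\boldsymbol{0}) = a \sin\phi - c\cos\phi\). The short computation below only reuses the tutorial's parameter values. It agrees with the value reported above to about ten significant digits.

```python
import math

c = 3.45  # [MPa] cohesion
phi = 30.0 * math.pi / 180.0  # [rad] friction angle
a = 0.26 * c / math.tan(phi)  # [MPa] tension cut-off parameter

# At zero stress I1 = J2 = 0 and a_g(phi) = a, so only two terms of f remain.
f_zero_stress = a * math.sin(phi) - c * math.cos(phi)
print(f"a = {a:.4f} MPa, f(0) = {f_zero_stress:.10f}")
```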
Solving the problem
Summing up, we apply the Newton method to solve the main weak problem. On each iteration of the main Newton loop, we solve elastoplastic constitutive equations by using the second, inner, Newton method at each Gauss point. Thanks to the framework and the JAX library, the final interface is general enough to be applied to other plasticity models.
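The structure of this incremental-iterative scheme can be reproduced on a toy problem with a single unknown. The scheme resets the increment Du, iterates until the residual has dropped by relative_tolerance, then accumulates u and stores the converged state as sigma_n. In the sketch below, which is not part of the tutorial, a hypothetical linear spring in parallel with an elastic-perfectly-plastic spring stands in for the finite element model. The closed-form 1D return mapping in plastic_spring plays the role of the JAX inner Newton solver, returning both the internal force and the consistent tangent.

```python
import numpy as np

# Hypothetical one-degree-of-freedom model: a linear spring (k1) in parallel
# with an elastic-perfectly-plastic spring (stiffness k2, yield force F_y).
k1, k2, F_y = 100.0, 300.0, 6.0


def plastic_spring(Du, force_n):
    """1D return mapping: internal force and consistent tangent for increment Du."""
    trial = force_n + k2 * Du
    if abs(trial) <= F_y:
        return trial, k2  # elastic
    return np.sign(trial) * F_y, 0.0  # plastic: force capped, zero tangent


def solve(load_steps, max_iterations=20, relative_tolerance=1e-10):
    u, force_n = 0.0, 0.0  # converged state of the previous load step
    history = []
    for load in load_steps:
        Du = 0.0  # reset the increment at each load step
        force, tangent = plastic_spring(Du, force_n)
        residual_0 = abs(k1 * u + force - load)
        residual = residual_0
        for _ in range(max_iterations):  # outer Newton loop
            if residual <= relative_tolerance * residual_0:
                break
            Du -= (k1 * (u + Du) + force - load) / (k1 + tangent)
            force, tangent = plastic_spring(Du, force_n)  # "inner" constitutive solve
            residual = abs(k1 * (u + Du) + force - load)
        u += Du  # accumulate the total displacement
        force_n = force  # store the converged state (like sigma_n)
        history.append((load, u))
    return history


for load, u in solve(np.linspace(1.0, 12.0, 12)):
    print(f"load = {load:5.2f}, u = {u:.5f}")
```

Past the yield load \((k_1 + k_2) F_y / k_2 = 8\), the plastic spring no longer contributes to the tangent, and the printed displacements follow \(u = (P - F_y)/k_1\).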
external_operator_problem = LinearProblem(J_replaced, -F_replaced, Du, bcs=bcs)
x_point = np.array([[0, H, 0]])
cells, points_on_process = find_cell_by_point(domain, x_point)
# parameters of the manual Newton method
max_iterations, relative_tolerance = 200, 1e-8
load_steps_1 = np.linspace(2, 21, 40)
load_steps_2 = np.linspace(21, 22.75, 20)[1:]
load_steps = np.concatenate([load_steps_1, load_steps_2])
num_increments = len(load_steps)
results = np.zeros((num_increments + 1, 2))
for i, load in enumerate(load_steps):
    q.value = load * np.array([0, -gamma])
    external_operator_problem.assemble_vector()
    residual_0 = external_operator_problem.b.norm()
    residual = residual_0
    Du.x.array[:] = 0
    if MPI.COMM_WORLD.rank == 0:
        print(f"Load increment #{i}, load: {load}, initial residual: {residual_0}")
    for iteration in range(0, max_iterations):
        if residual / residual_0 < relative_tolerance:
            break
        if MPI.COMM_WORLD.rank == 0:
            print(f"\tOuter Newton iteration #{iteration}")
        external_operator_problem.assemble_matrix()
        external_operator_problem.solve(du)
        Du.vector.axpy(1.0, du.vector)
        Du.x.scatter_forward()
        evaluated_operands = evaluate_operands(F_external_operators)
        ((_, sigma_new),) = evaluate_external_operators(J_external_operators, evaluated_operands)
        # Direct access to the external operator values
        sigma.ref_coefficient.x.array[:] = sigma_new
        external_operator_problem.assemble_vector()
        residual = external_operator_problem.b.norm()
        if MPI.COMM_WORLD.rank == 0:
            print(f"\tResidual: {residual}\n")
    u.vector.axpy(1.0, Du.vector)
    u.x.scatter_forward()
    sigma_n.x.array[:] = sigma.ref_coefficient.x.array
    if len(points_on_process) > 0:
        results[i + 1, :] = (u.eval(points_on_process, cells)[0], load)
print(f"Slope stability factor: {-q.value[-1]*H/c}")
Load increment #0, load: 2.0, initial residual: 0.027573900703382316
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1]
Counts of unique number of iterations: [15000]
Maximum f: -0.2983932974714918
Maximum residual: 0.0
Residual: 7.890919449890783e-14
Load increment #1, load: 2.4871794871794872, initial residual: 0.006716719402111263
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 3]
Counts of unique number of iterations: [14999 1]
Maximum f: 0.28795356169797115
Maximum residual: 1.6594296223543263e-15
Residual: 0.002218734438292365
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 3]
Counts of unique number of iterations: [14999 1]
Maximum f: 0.4490497372726332
Maximum residual: 1.520953013893946e-10
Residual: 8.901387733059753e-05
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 3]
Counts of unique number of iterations: [14999 1]
Maximum f: 0.4526047639809363
Maximum residual: 1.7843660861833476e-10
Residual: 8.728694698004881e-08
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 3]
Counts of unique number of iterations: [14999 1]
Maximum f: 0.4526085284602561
Maximum residual: 1.7846736539222126e-10
Residual: 8.984456565965332e-14
Load increment #2, load: 2.9743589743589745, initial residual: 0.006716719402105831
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 3 4]
Counts of unique number of iterations: [14998 1 1]
Maximum f: 0.9590600597654926
Maximum residual: 1.775929382738241e-11
Residual: 0.004802356196802484
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 3 4]
Counts of unique number of iterations: [14998 1 1]
Maximum f: 1.1348741666886801
Maximum residual: 4.090145411458145e-09
Residual: 9.211659039241788e-05
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 3 4]
Counts of unique number of iterations: [14998 1 1]
Maximum f: 1.1350924176788273
Maximum residual: 4.525723849207466e-09
Residual: 1.693162840012877e-08
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 3 4]
Counts of unique number of iterations: [14998 1 1]
Maximum f: 1.1350931606257837
Maximum residual: 4.525858231395645e-09
Residual: 6.1782428917217444e-15
Load increment #3, load: 3.4615384615384617, initial residual: 0.006716719402106011
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 3 4]
Counts of unique number of iterations: [14996 3 1]
Maximum f: 1.209400553063301
Maximum residual: 7.248275016699231e-09
Residual: 0.004577764018417838
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 3 4]
Counts of unique number of iterations: [14995 4 1]
Maximum f: 1.3971321183935568
Maximum residual: 1.7753406791897168e-08
Residual: 0.0008492026933011995
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 3 4]
Counts of unique number of iterations: [14995 4 1]
Maximum f: 1.4248006611944821
Maximum residual: 1.3143373292644214e-08
Residual: 5.803956088157307e-06
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 3 4]
Counts of unique number of iterations: [14995 4 1]
Maximum f: 1.4248654098565905
Maximum residual: 1.3115351889884023e-08
Residual: 1.7549208096475178e-10
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 3 4]
Counts of unique number of iterations: [14995 4 1]
Maximum f: 1.424865413631483
Maximum residual: 1.3115347787810599e-08
Residual: 6.274298032534265e-15
Load increment #4, load: 3.948717948717949, initial residual: 0.006716719402106086
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14992 5 1 2]
Maximum f: 1.6622114112939035
Maximum residual: 1.6729879566299043e-09
Residual: 0.005991366712426636
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14993 2 3 2]
Maximum f: 1.9712452063909374
Maximum residual: 1.9702586209052233e-08
Residual: 0.0008411325822618858
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14993 2 3 2]
Maximum f: 1.9669201261015217
Maximum residual: 9.841984594686828e-09
Residual: 1.0331838834946247e-06
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14993 2 3 2]
Maximum f: 1.966990437644761
Maximum residual: 9.885406143576431e-09
Residual: 6.5642790536169226e-12
Load increment #5, load: 4.435897435897436, initial residual: 0.006716719402089788
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14988 5 6 1]
Maximum f: 2.162445279431028
Maximum residual: 1.71677022158221e-08
Residual: 0.0063871310319332025
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14986 8 4 2]
Maximum f: 2.2145902775297404
Maximum residual: 7.373647765391979e-09
Residual: 0.0018311090812132959
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14986 8 4 2]
Maximum f: 2.2478311837587746
Maximum residual: 1.0295518787805423e-08
Residual: 6.186049762218369e-06
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14986 8 4 2]
Maximum f: 2.24784731708579
Maximum residual: 1.030700085508072e-08
Residual: 1.1402754423442121e-10
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14986 8 4 2]
Maximum f: 2.2478473195050532
Maximum residual: 1.0307001773475238e-08
Residual: 6.214271459043541e-15
Load increment #6, load: 4.923076923076923, initial residual: 0.00671671940210594
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14980 10 8 2]
Maximum f: 2.4123407381238366
Maximum residual: 1.3702129597133274e-08
Residual: 0.007208784369928679
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14978 11 9 2]
Maximum f: 2.637309208631247
Maximum residual: 6.074572172319527e-09
Residual: 0.0010890919977788044
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14978 11 9 2]
Maximum f: 2.657703883779909
Maximum residual: 5.0240029089825934e-09
Residual: 1.2600502802830214e-05
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14978 11 9 2]
Maximum f: 2.6578217315765156
Maximum residual: 5.0069635910571806e-09
Residual: 1.2853246080780249e-09
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14978 11 9 2]
Maximum f: 2.6578217438363816
Maximum residual: 5.006962713000526e-09
Residual: 6.323108216488874e-15
Load increment #7, load: 5.410256410256411, initial residual: 0.0067167194021057995
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14971 14 12 3]
Maximum f: 2.897228427974882
Maximum residual: 3.061045583467829e-08
Residual: 0.007185163071562447
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14970 13 12 5]
Maximum f: 3.074460145228296
Maximum residual: 9.37743561600887e-09
Residual: 0.0036353854676650877
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14969 16 10 5]
Maximum f: 3.106835325618557
Maximum residual: 2.071550588827723e-08
Residual: 0.0006614649106003488
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14969 16 10 5]
Maximum f: 3.1113366861625216
Maximum residual: 2.086743436134511e-08
Residual: 2.3272857870653557e-06
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14969 16 10 5]
Maximum f: 3.1113629293906118
Maximum residual: 2.087013565612822e-08
Residual: 1.601845389690442e-11
Load increment #8, load: 5.897435897435898, initial residual: 0.006716719402095755
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14956 1 26 15 2]
Maximum f: 3.3170712529546322
Maximum residual: 1.7622990960697624e-08
Residual: 0.008492114261895212
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14957 24 16 3]
Maximum f: 3.6710312111713876
Maximum residual: 2.3222829685364223e-08
Residual: 0.005488364093321883
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14956 25 16 3]
Maximum f: 3.691435244584857
Maximum residual: 2.551361351098158e-08
Residual: 8.3194382111664e-05
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14956 25 16 3]
Maximum f: 3.691421400465235
Maximum residual: 2.5406554782478082e-08
Residual: 1.4048353900921005e-07
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14956 25 16 3]
Maximum f: 3.691421599329771
Maximum residual: 2.540632866486787e-08
Residual: 1.5064173841071146e-13
Load increment #9, load: 6.384615384615384, initial residual: 0.006716719402106365
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14936 46 15 3]
Maximum f: 4.045554160772726
Maximum residual: 4.046223127970408e-08
Residual: 0.008836812772120842
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14935 42 18 5]
Maximum f: 4.316638567616666
Maximum residual: 4.943413709278408e-08
Residual: 0.010718894090114112
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14934 43 17 6]
Maximum f: 4.355314378915885
Maximum residual: 1.706109959545318e-08
Residual: 0.0002441838429221422
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14934 43 17 6]
Maximum f: 4.356975910370981
Maximum residual: 1.751551077434608e-08
Residual: 1.3661173343927448e-07
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14934 43 17 6]
Maximum f: 4.356976299040825
Maximum residual: 1.7515666822217637e-08
Residual: 5.14863847162322e-14
Load increment #10, load: 6.871794871794871, initial residual: 0.00671671940210583
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14915 59 23 3]
Maximum f: 4.73899401300072
Maximum residual: 5.3984419776923646e-08
Residual: 0.008959963182434999
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14915 55 24 6]
Maximum f: 4.994258500258619
Maximum residual: 2.3250001786574396e-08
Residual: 0.011397174045726712
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14913 56 25 6]
Maximum f: 5.018175039270362
Maximum residual: 2.3162595968289013e-08
Residual: 0.00034539497553902706
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14913 56 25 6]
Maximum f: 5.021313077908896
Maximum residual: 2.3185394981555668e-08
Residual: 2.2738139720737685e-07
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14913 56 25 6]
Maximum f: 5.021314343448537
Maximum residual: 2.3185398022509775e-08
Residual: 1.3678886974249533e-13
Load increment #11, load: 7.358974358974359, initial residual: 0.006716719402106141
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14895 1 69 32 3]
Maximum f: 5.425401250002199
Maximum residual: 3.385494654584838e-08
Residual: 0.008873442760279588
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14892 1 70 30 7]
Maximum f: 5.660279596506831
Maximum residual: 3.118953666047237e-08
Residual: 0.0015471905163801738
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14892 1 71 29 7]
Maximum f: 5.677810657796213
Maximum residual: 4.7779089907947863e-08
Residual: 1.083841467135265e-05
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14892 1 71 29 7]
Maximum f: 5.677921597143509
Maximum residual: 4.782765723501018e-08
Residual: 3.9821665570590263e-10
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14892 1 71 29 7]
Maximum f: 5.677921600559749
Maximum residual: 4.782766439864943e-08
Residual: 6.718653015297411e-15
Load increment #12, load: 7.846153846153846, initial residual: 0.006716719402105915
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14876 84 37 3]
Maximum f: 6.049350844711803
Maximum residual: 5.537682466709748e-08
Residual: 0.00909635116112651
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14872 85 39 4]
Maximum f: 6.2395293645867795
Maximum residual: 4.33070938661249e-08
Residual: 0.0016211998546290778
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14871 86 39 4]
Maximum f: 6.250768031311917
Maximum residual: 6.110279334605412e-08
Residual: 0.00015603184395129548
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14871 86 39 4]
Maximum f: 6.2502649319295145
Maximum residual: 6.212927904010676e-08
Residual: 1.0783311657203034e-07
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14871 86 39 4]
Maximum f: 6.250266596016187
Maximum residual: 6.212979557458811e-08
Residual: 2.7384705549305455e-14
Load increment #13, load: 8.333333333333332, initial residual: 0.0067167194021057995
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14856 100 42 2]
Maximum f: 6.585636895195892
Maximum residual: 5.193548486682376e-08
Residual: 0.009007878894704023
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14855 1 97 45 2]
Maximum f: 6.755125621186041
Maximum residual: 5.688046285266084e-08
Residual: 0.0012749818243058947
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14853 1 99 45 2]
Maximum f: 6.765747790570044
Maximum residual: 6.443474499515509e-08
Residual: 0.0003032242899863522
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14853 1 99 45 2]
Maximum f: 6.767071381472654
Maximum residual: 6.496357412460641e-08
Residual: 1.489896587160495e-07
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14853 1 99 45 2]
Maximum f: 6.767071961553549
Maximum residual: 6.496390166810064e-08
Residual: 6.225101140482484e-14
Load increment #14, load: 8.820512820512821, initial residual: 0.0067167194021065975
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 3 4 5]
Counts of unique number of iterations: [14840 113 46 1]
Maximum f: 7.113663468184065
Maximum residual: 4.368312081897168e-08
Residual: 0.008041984419028435
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14836 2 112 48 2]
Maximum f: 7.252960401035864
Maximum residual: 7.477724366204454e-08
Residual: 0.001324125165071309
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14834 1 115 48 2]
Maximum f: 7.25737579344905
Maximum residual: 6.980241274251221e-08
Residual: 7.75583446220206e-05
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14834 1 115 48 2]
Maximum f: 7.25709827678001
Maximum residual: 6.908151030391233e-08
Residual: 2.3409400728015878e-08
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14834 1 115 48 2]
Maximum f: 7.257098622929357
Maximum residual: 6.908160534971312e-08
Residual: 7.253142220446733e-15
Load increment #15, load: 9.307692307692307, initial residual: 0.00671671940210586
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14823 2 130 43 2]
Maximum f: 7.58041617721587
Maximum residual: 4.595512338717213e-08
Residual: 0.006820691009076998
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14820 3 127 48 2]
Maximum f: 7.693992069459707
Maximum residual: 4.3012469257073165e-08
Residual: 0.0004652039687665085
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14820 3 126 49 2]
Maximum f: 7.69587120975468
Maximum residual: 4.2245463838736303e-08
Residual: 5.193523274965661e-07
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14820 3 126 49 2]
Maximum f: 7.695876907437409
Maximum residual: 4.224525949953697e-08
Residual: 7.659190506996182e-13
Load increment #16, load: 9.794871794871796, initial residual: 0.006716719402108412
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14808 2 142 46 2]
Maximum f: 7.990509748994471
Maximum residual: 4.004581026336205e-08
Residual: 0.006701736964370909
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14805 4 139 50 2]
Maximum f: 8.08955046110169
Maximum residual: 3.9413228833501414e-08
Residual: 0.0008807439366930102
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14805 4 139 50 2]
Maximum f: 8.091632218673137
Maximum residual: 3.3707938270779526e-08
Residual: 2.8174935530515266e-06
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14805 4 139 50 2]
Maximum f: 8.091634655377234
Maximum residual: 3.3706325363099663e-08
Residual: 3.626992861479351e-11
Load increment #17, load: 10.282051282051281, initial residual: 0.0067167194021178775
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14795 5 154 41 5]
Maximum f: 8.358501303640198
Maximum residual: 5.0813990167032605e-08
Residual: 0.006330612673184364
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14793 3 154 45 5]
Maximum f: 8.437903735293917
Maximum residual: 5.021345716229347e-08
Residual: 0.00045101693866335786
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14793 3 154 45 5]
Maximum f: 8.44003483745481
Maximum residual: 5.032880438274941e-08
Residual: 6.877521849996479e-07
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14793 3 154 45 5]
Maximum f: 8.440036781373985
Maximum residual: 5.0328930922830774e-08
Residual: 1.6053143784455589e-12
Load increment #18, load: 10.769230769230768, initial residual: 0.006716719402110584
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14780 3 177 37 3]
Maximum f: 8.68456474312378
Maximum residual: 9.572256835248304e-08
Residual: 0.006321576458610434
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14779 4 171 43 3]
Maximum f: 8.752245455320418
Maximum residual: 4.244759979410072e-08
Residual: 0.000610492826268375
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14778 4 172 43 3]
Maximum f: 8.752935060754327
Maximum residual: 4.2515114894425285e-08
Residual: 0.00023235845386350724
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14778 3 173 43 3]
Maximum f: 8.753156792962672
Maximum residual: 4.252849529104911e-08
Residual: 2.97884917091799e-07
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14778 3 173 43 3]
Maximum f: 8.753157018624599
Maximum residual: 4.2528507929369336e-08
Residual: 3.806034398229405e-13
Load increment #19, load: 11.256410256410255, initial residual: 0.006716719402108087
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14765 3 195 36 1]
Maximum f: 8.95829439775059
Maximum residual: 8.484668722165286e-08
Residual: 0.008132000178931911
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14762 2 196 38 2]
Maximum f: 8.998791069662573
Maximum residual: 1.1058616519530065e-07
Residual: 0.001997392228008905
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14762 2 196 38 2]
Maximum f: 9.005204366112117
Maximum residual: 1.1066259012209936e-07
Residual: 2.525996161972977e-05
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14762 2 196 38 2]
Maximum f: 9.005208507754293
Maximum residual: 1.106641134603762e-07
Residual: 2.4448539867613254e-09
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14762 2 196 38 2]
Maximum f: 9.005208509234254
Maximum residual: 1.1066411422084699e-07
Residual: 7.870182825691626e-15
Load increment #20, load: 11.743589743589743, initial residual: 0.006716719402105844
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14746 4 221 26 3]
Maximum f: 9.133896502162614
Maximum residual: 4.002586471346086e-08
Residual: 0.008896728060078223
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14745 3 218 31 3]
Maximum f: 9.192693190649075
Maximum residual: 6.888309926607873e-08
Residual: 0.0038934284315613244
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14745 4 217 31 3]
Maximum f: 9.19757041583747
Maximum residual: 6.285697180274128e-08
Residual: 3.79684806331697e-05
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14745 4 217 31 3]
Maximum f: 9.197584054900588
Maximum residual: 6.28436612625426e-08
Residual: 4.2398694418568105e-09
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14745 4 217 31 3]
Maximum f: 9.197584057484399
Maximum residual: 6.284365745589517e-08
Residual: 7.851337464868925e-15
Load increment #21, load: 12.23076923076923, initial residual: 0.006716719402105858
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14728 6 243 19 4]
Maximum f: 9.309317989824686
Maximum residual: 5.0039574973725875e-08
Residual: 0.009216976218156541
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14727 5 242 22 4]
Maximum f: 9.372210269033676
Maximum residual: 4.729179352403944e-08
Residual: 0.005677561618555902
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14728 5 241 22 4]
Maximum f: 9.375041973280801
Maximum residual: 4.859850380520213e-08
Residual: 0.00019877358899286256
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14728 5 241 22 4]
Maximum f: 9.375115305044817
Maximum residual: 4.86429072529176e-08
Residual: 7.633714921926549e-09
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14728 5 241 22 4]
Maximum f: 9.375115306846656
Maximum residual: 4.8642903504418145e-08
Residual: 7.8710856118586e-15
Load increment #22, load: 12.717948717948717, initial residual: 0.006716719402106297
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14703 7 276 10 4]
Maximum f: 9.63173443468616
Maximum residual: 7.02878085279497e-08
Residual: 0.010103758455058044
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14703 7 270 16 4]
Maximum f: 10.266068194665394
Maximum residual: 4.788828080101722e-08
Residual: 0.005635378509327741
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14702 7 272 15 4]
Maximum f: 10.29453701121926
Maximum residual: 1.2234570985676737e-07
Residual: 0.000645306827327642
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14703 6 272 15 4]
Maximum f: 10.305089262350577
Maximum residual: 1.212167964781357e-07
Residual: 2.377198286444656e-06
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14703 6 272 15 4]
Maximum f: 10.30514001407545
Maximum residual: 1.2121605292000952e-07
Residual: 3.2049746246585445e-12
Load increment #23, load: 13.205128205128204, initial residual: 0.006716719402092745
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14669 12 304 13 2]
Maximum f: 11.025940478436011
Maximum residual: 5.2872662692572636e-08
Residual: 0.010006180815023176
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14666 6 314 12 2]
Maximum f: 12.212967154051473
Maximum residual: 3.4404539287631105e-08
Residual: 0.006138534558733101
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14665 8 313 12 2]
Maximum f: 12.289198038530676
Maximum residual: 4.4927080275289845e-08
Residual: 0.00044609546246340874
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14666 7 313 12 2]
Maximum f: 12.292335547743185
Maximum residual: 4.7197448737045196e-08
Residual: 2.918169783237978e-05
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14666 7 313 12 2]
Maximum f: 12.292564839552277
Maximum residual: 4.7248283087204915e-08
Residual: 9.409193163893957e-10
Outer Newton iteration #5
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14666 7 313 12 2]
Maximum f: 12.292564838148865
Maximum residual: 4.724828385548392e-08
Residual: 8.393446094006253e-15
Load increment #24, load: 13.692307692307692, initial residual: 0.006716719402105958
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14626 10 349 13 2]
Maximum f: 13.752176654898795
Maximum residual: 7.098611881554246e-08
Residual: 0.009288087620594824
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14628 3 354 13 2]
Maximum f: 14.526046495213562
Maximum residual: 6.025212812674243e-08
Residual: 0.017487824708030283
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14628 4 353 13 2]
Maximum f: 14.630545477040274
Maximum residual: 7.443567043345022e-08
Residual: 0.0011315236982134143
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14628 4 353 13 2]
Maximum f: 14.644482393886564
Maximum residual: 7.505198406562835e-08
Residual: 1.407048587391681e-06
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14628 4 353 13 2]
Maximum f: 14.64448759202658
Maximum residual: 7.505278213247118e-08
Residual: 3.954982987435764e-12
Load increment #25, load: 14.179487179487179, initial residual: 0.006716719402093062
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14585 12 389 12 2]
Maximum f: 15.668897835568425
Maximum residual: 1.0359324282568697e-07
Residual: 0.010202728243486267
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14585 4 395 14 2]
Maximum f: 16.897476127769465
Maximum residual: 7.49218816485522e-08
Residual: 0.018204886208027227
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14586 6 393 13 2]
Maximum f: 17.005942479359888
Maximum residual: 1.0354994585762806e-07
Residual: 0.001766524529136301
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14586 6 393 13 2]
Maximum f: 17.01076070657941
Maximum residual: 1.0239552089224122e-07
Residual: 2.4867687483445258e-06
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14586 6 393 13 2]
Maximum f: 17.010757834581053
Maximum residual: 1.0239377241020804e-07
Residual: 1.1557380703554376e-11
Load increment #26, load: 14.666666666666666, initial residual: 0.006716719402090528
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14537 12 438 12 1]
Maximum f: 18.39294355205917
Maximum residual: 8.928276620617958e-08
Residual: 0.010739219905001303
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14537 5 442 15 1]
Maximum f: 19.448449565617622
Maximum residual: 1.143223709998003e-07
Residual: 0.02179331102773085
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14537 6 440 16 1]
Maximum f: 19.593818367596622
Maximum residual: 1.0728628556114985e-07
Residual: 4.878299983263898e-05
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14537 6 440 16 1]
Maximum f: 19.59392537028595
Maximum residual: 1.074182263339286e-07
Residual: 6.6946089234570265e-09
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14537 6 440 16 1]
Maximum f: 19.593925341927573
Maximum residual: 1.0741824914500163e-07
Residual: 8.886468278163498e-15
Load increment #27, load: 15.153846153846153, initial residual: 0.006716719402105671
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14480 12 490 17 1]
Maximum f: 20.96674898559733
Maximum residual: 7.974201419472802e-08
Residual: 0.010417054077772969
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14484 8 490 17 1]
Maximum f: 22.221434026645674
Maximum residual: 5.892301718550108e-08
Residual: 0.02632335173036927
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14481 9 492 17 1]
Maximum f: 22.41046674910591
Maximum residual: 8.175009670862007e-08
Residual: 0.0009187005860068906
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14481 8 493 17 1]
Maximum f: 22.405318415035364
Maximum residual: 8.272589945751361e-08
Residual: 6.865470019662356e-06
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14481 8 493 17 1]
Maximum f: 22.405274475571368
Maximum residual: 8.27281108923625e-08
Residual: 2.7254233159478655e-10
Outer Newton iteration #5
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14481 8 493 17 1]
Maximum f: 22.405274473192478
Maximum residual: 8.272810964860568e-08
Residual: 9.12277783847181e-15
Load increment #28, load: 15.64102564102564, initial residual: 0.006716719402106747
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14413 12 559 14 2]
Maximum f: 24.305554899323425
Maximum residual: 1.1422847452394309e-07
Residual: 0.010984385486294516
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14418 8 555 17 2]
Maximum f: 25.374448302940365
Maximum residual: 5.85257914164579e-08
Residual: 0.023822453291475986
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14415 7 559 17 2]
Maximum f: 25.524100384956604
Maximum residual: 5.2111628962448404e-08
Residual: 0.000632110912396084
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14415 6 560 17 2]
Maximum f: 25.53493731529084
Maximum residual: 5.280219619940484e-08
Residual: 1.7406786524159857e-06
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14415 6 560 17 2]
Maximum f: 25.53494661804102
Maximum residual: 5.2802884829194224e-08
Residual: 1.4314551061912073e-11
Load increment #29, load: 16.128205128205128, initial residual: 0.006716719402092326
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14353 17 608 20 2]
Maximum f: 27.38821927530687
Maximum residual: 4.498610418770207e-08
Residual: 0.01220183059585991
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14356 11 612 19 2]
Maximum f: 28.584557985385988
Maximum residual: 7.832680090454439e-08
Residual: 0.019788844210587617
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14353 11 615 19 2]
Maximum f: 28.668309836300896
Maximum residual: 8.144229824691069e-08
Residual: 0.0009708591755653929
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14353 11 615 19 2]
Maximum f: 28.66667178001655
Maximum residual: 8.170143359972085e-08
Residual: 8.220389225920535e-06
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14353 11 615 19 2]
Maximum f: 28.66664018384156
Maximum residual: 8.170199798887568e-08
Residual: 3.342775554120138e-10
Outer Newton iteration #5
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14353 11 615 19 2]
Maximum f: 28.666640181777673
Maximum residual: 8.17019934029446e-08
Residual: 9.883786062574305e-15
Load increment #30, load: 16.615384615384613, initial residual: 0.0067167194021054985
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14291 19 668 21 1]
Maximum f: 30.339192570017275
Maximum residual: 3.2793459871275615e-07
Residual: 0.014781056742039755
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14283 18 679 19 1]
Maximum f: 31.862611315017126
Maximum residual: 9.667634863582554e-08
Residual: 0.009732116272655258
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14281 19 679 20 1]
Maximum f: 32.10647961251519
Maximum residual: 7.949837076224848e-08
Residual: 0.0024541551144318507
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14281 19 679 20 1]
Maximum f: 32.12478052643899
Maximum residual: 8.192197620886831e-08
Residual: 1.6632697760450347e-06
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14281 19 679 20 1]
Maximum f: 32.124794432722226
Maximum residual: 8.192384015557072e-08
Residual: 5.307261651691642e-12
Load increment #31, load: 17.102564102564102, initial residual: 0.006716719402101556
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14206 16 750 26 2]
Maximum f: 34.25259429086194
Maximum residual: 5.6444343498231866e-08
Residual: 0.01297381597457321
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14197 17 760 24 2]
Maximum f: 35.49385941710134
Maximum residual: 9.649000998659729e-08
Residual: 0.013085384122132166
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14194 14 766 24 2]
Maximum f: 35.672142068577976
Maximum residual: 1.0100674615709845e-07
Residual: 0.000476777805004166
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14194 12 768 24 2]
Maximum f: 35.68206018085639
Maximum residual: 1.0120665527535234e-07
Residual: 5.414426119610152e-07
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14194 12 768 24 2]
Maximum f: 35.68206441431284
Maximum residual: 1.0120679725094297e-07
Residual: 1.0066942045793446e-12
Load increment #32, load: 17.58974358974359, initial residual: 0.006716719402103371
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14109 20 842 27 2]
Maximum f: 37.821780614346245
Maximum residual: 8.912766165850012e-08
Residual: 0.013265122380362617
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14103 13 853 29 2]
Maximum f: 39.33636620367054
Maximum residual: 1.8231828705124016e-07
Residual: 0.017344808348780303
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14102 16 851 29 2]
Maximum f: 39.4609815854675
Maximum residual: 1.8949256158985875e-07
Residual: 0.002786063696087652
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14101 17 851 29 2]
Maximum f: 39.4946982973061
Maximum residual: 1.9103679118206756e-07
Residual: 3.150485872403823e-05
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14101 17 851 29 2]
Maximum f: 39.49493569259684
Maximum residual: 1.9104773239249382e-07
Residual: 6.003931266267597e-09
Outer Newton iteration #5
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14101 17 851 29 2]
Maximum f: 39.49493570942146
Maximum residual: 1.9104773439967617e-07
Residual: 1.1392532511237408e-14
Load increment #33, load: 18.076923076923077, initial residual: 0.006716719402106395
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [14007 26 937 28 2]
Maximum f: 41.9098201180838
Maximum residual: 1.7190098587950488e-07
Residual: 0.013883417561548022
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13994 25 947 33 1]
Maximum f: 43.305826051098165
Maximum residual: 5.715168094260168e-08
Residual: 0.013918975069064685
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13991 24 950 34 1]
Maximum f: 43.46183461309083
Maximum residual: 8.336720141648832e-08
Residual: 0.0006913086206251028
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13991 24 950 34 1]
Maximum f: 43.47158188535475
Maximum residual: 8.532034846251978e-08
Residual: 1.3112001597501342e-06
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13991 24 950 34 1]
Maximum f: 43.47158513421427
Maximum residual: 8.532191981945231e-08
Residual: 3.627195920430388e-12
Load increment #34, load: 18.564102564102562, initial residual: 0.006716719402095581
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13902 22 1042 33 1]
Maximum f: 45.91480212969232
Maximum residual: 2.5231995737558407e-07
Residual: 0.013815534918141121
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13893 22 1049 35 1]
Maximum f: 47.62465319970073
Maximum residual: 1.133215878935869e-07
Residual: 0.016106095053431354
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13889 20 1055 35 1]
Maximum f: 47.78139392957633
Maximum residual: 3.379388837129073e-08
Residual: 0.0007718361445462001
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13888 18 1058 35 1]
Maximum f: 47.79394068855864
Maximum residual: 9.146917710432406e-08
Residual: 1.920616523282387e-05
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13888 18 1058 35 1]
Maximum f: 47.79397066659853
Maximum residual: 9.139406835873182e-08
Residual: 4.1750732579912033e-10
Outer Newton iteration #5
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13888 18 1058 35 1]
Maximum f: 47.7939706688178
Maximum residual: 9.13940654724866e-08
Residual: 1.1903612989723599e-14
Load increment #35, load: 19.05128205128205, initial residual: 0.006716719402106123
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13785 28 1152 33 2]
Maximum f: 50.14385585812778
Maximum residual: 2.7897042243367246e-07
Residual: 0.015919379907346706
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13779 18 1167 34 2]
Maximum f: 52.23835524504635
Maximum residual: 4.877148053583346e-08
Residual: 0.01614774566949417
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13779 24 1160 35 2]
Maximum f: 52.44076507108261
Maximum residual: 7.126955269651264e-08
Residual: 0.0013833112871311163
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13778 24 1161 35 2]
Maximum f: 52.45168152789199
Maximum residual: 7.02549186065997e-08
Residual: 4.661538350174757e-05
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13778 24 1161 35 2]
Maximum f: 52.4521277518502
Maximum residual: 7.032483010506796e-08
Residual: 1.5900557642300973e-08
Outer Newton iteration #5
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13778 24 1161 35 2]
Maximum f: 52.45212785463692
Maximum residual: 7.032487354410769e-08
Residual: 1.2504607353543192e-14
Load increment #36, load: 19.538461538461537, initial residual: 0.006716719402105516
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13669 33 1259 36 3]
Maximum f: 55.56784988592688
Maximum residual: 6.732028361136979e-08
Residual: 0.016487333139811673
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13653 28 1280 36 3]
Maximum f: 57.52067499406361
Maximum residual: 6.714560370712857e-08
Residual: 0.00955516114956653
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13649 30 1282 36 3]
Maximum f: 57.732233414592876
Maximum residual: 1.0327807640700029e-07
Residual: 0.0012033324130438262
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13649 29 1283 36 3]
Maximum f: 57.74985451651768
Maximum residual: 1.2704672898236444e-07
Residual: 4.14564476269671e-06
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13649 29 1283 36 3]
Maximum f: 57.749868232462234
Maximum residual: 1.2707298784976296e-07
Residual: 4.720998897913752e-11
Load increment #37, load: 20.025641025641026, initial residual: 0.006716719401962263
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13524 35 1401 39 1]
Maximum f: 61.06990940367414
Maximum residual: 7.239404353078795e-08
Residual: 0.015428771935675076
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13507 39 1413 40 1]
Maximum f: 63.56467957583643
Maximum residual: 1.339659214037089e-07
Residual: 0.01460246934657777
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13506 36 1417 40 1]
Maximum f: 63.79127538974135
Maximum residual: 9.75441023864128e-08
Residual: 0.0021520086173601235
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13505 37 1417 40 1]
Maximum f: 63.80501931122347
Maximum residual: 1.0692960590773911e-07
Residual: 9.969636774588729e-05
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13505 37 1417 40 1]
Maximum f: 63.80645015632711
Maximum residual: 1.0928016991558215e-07
Residual: 9.220161004743297e-08
Outer Newton iteration #5
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13505 37 1417 40 1]
Maximum f: 63.80645060186934
Maximum residual: 1.0928077458087668e-07
Residual: 3.2700114269518456e-14
Load increment #38, load: 20.51282051282051, initial residual: 0.006716719402106495
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13393 36 1531 39 1]
Maximum f: 67.75666130501165
Maximum residual: 4.4283993159215054e-07
Residual: 0.017517429798991468
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13365 31 1561 41 2]
Maximum f: 70.58577902838182
Maximum residual: 8.057051642698636e-08
Residual: 0.015748989628338424
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13365 26 1566 41 2]
Maximum f: 70.93025167118343
Maximum residual: 1.0978850755445194e-07
Residual: 0.0005169866997214393
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13364 27 1566 41 2]
Maximum f: 70.93404932897471
Maximum residual: 1.0829233995382601e-07
Residual: 1.490303654008798e-06
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13364 27 1566 41 2]
Maximum f: 70.93406548829532
Maximum residual: 1.0829344733671326e-07
Residual: 1.209903938918342e-11
Load increment #39, load: 21.0, initial residual: 0.006716719402106013
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13215 38 1706 37 4]
Maximum f: 75.39347837714647
Maximum residual: 8.506697471963536e-08
Residual: 0.01837161956183727
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13194 34 1728 39 5]
Maximum f: 79.15199508212592
Maximum residual: 4.4074659899683117e-07
Residual: 0.01716196534684848
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13191 27 1740 37 5]
Maximum f: 79.58593038621314
Maximum residual: 1.1253553490195802e-07
Residual: 0.0013369308810265517
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13191 27 1740 37 5]
Maximum f: 79.60996466352213
Maximum residual: 1.0675518251501684e-07
Residual: 2.9209339779508552e-06
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [13191 27 1740 37 5]
Maximum f: 79.6099757087359
Maximum residual: 1.0674974716251821e-07
Residual: 3.630111799659412e-11
Load increment #40, load: 21.092105263157894, initial residual: 0.0012698506902034635
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3 4]
Counts of unique number of iterations: [13165 276 1557 2]
Maximum f: 16.089757616598472
Maximum residual: 3.2908599460270386e-08
Residual: 0.003132131412589753
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3 4]
Counts of unique number of iterations: [13163 327 1509 1]
Maximum f: 16.108413336488766
Maximum residual: 2.0737017708986704e-07
Residual: 0.0001309750793832361
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3 4]
Counts of unique number of iterations: [13163 327 1509 1]
Maximum f: 16.111745280753176
Maximum residual: 2.1805079115737293e-07
Residual: 9.146064805775662e-08
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3 4]
Counts of unique number of iterations: [13163 327 1509 1]
Maximum f: 16.11174628586943
Maximum residual: 2.1805780476752925e-07
Residual: 3.107451965169591e-14
Load increment #41, load: 21.18421052631579, initial residual: 0.0012698506902877028
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [13125 336 1539]
Maximum f: 16.28403737522223
Maximum residual: 4.2841471057309377e-08
Residual: 0.0021467124200770507
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [13122 331 1547]
Maximum f: 16.49461045204182
Maximum residual: 5.4255894925230015e-08
Residual: 8.480986566737737e-05
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [13122 331 1547]
Maximum f: 16.49615522847548
Maximum residual: 5.361516560334869e-08
Residual: 3.1037346206766076e-08
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [13122 331 1547]
Maximum f: 16.496155577272127
Maximum residual: 5.361506334752265e-08
Residual: 6.1693080768153555e-15
Load increment #42, load: 21.276315789473685, initial residual: 0.0012698506902871934
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [13084 350 1566]
Maximum f: 16.7884132206062
Maximum residual: 5.093052054608339e-08
Residual: 0.002799389435763017
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [13081 334 1585]
Maximum f: 17.034439719495957
Maximum residual: 5.3117582675056584e-08
Residual: 0.00043717704714079774
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [13081 332 1587]
Maximum f: 17.041642871359343
Maximum residual: 5.5418745429041117e-08
Residual: 1.7721920312567786e-06
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [13081 332 1587]
Maximum f: 17.041659487244775
Maximum residual: 5.542202406010224e-08
Residual: 1.1277892132950674e-11
Load increment #43, load: 21.36842105263158, initial residual: 0.001269850690298939
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [13049 343 1608]
Maximum f: 17.2786143613338
Maximum residual: 5.6645147653051966e-08
Residual: 0.002593433873319041
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [13049 334 1617]
Maximum f: 17.490395122340544
Maximum residual: 5.568600439836382e-08
Residual: 0.000719086134272951
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [13049 334 1617]
Maximum f: 17.496370527697085
Maximum residual: 5.446815708418039e-08
Residual: 4.377588363472104e-07
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [13049 334 1617]
Maximum f: 17.496373183969098
Maximum residual: 5.446785456859546e-08
Residual: 4.710207314425706e-13
Load increment #44, load: 21.460526315789473, initial residual: 0.0012698506902881189
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [13013 351 1636]
Maximum f: 17.714686682904418
Maximum residual: 5.0147234961301736e-08
Residual: 0.00270186706057739
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [13011 346 1643]
Maximum f: 17.96512383204705
Maximum residual: 5.338741459191331e-08
Residual: 0.00042164585393472025
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [13011 347 1642]
Maximum f: 17.96958205927456
Maximum residual: 5.2649645443039856e-08
Residual: 9.47238924041019e-08
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [13011 347 1642]
Maximum f: 17.96958382006201
Maximum residual: 5.264937977173685e-08
Residual: 2.3160986662794185e-14
Load increment #45, load: 21.55263157894737, initial residual: 0.0012698506902885
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12971 364 1665]
Maximum f: 18.277469754284002
Maximum residual: 5.8774291411304744e-08
Residual: 0.0030246887375881427
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12972 346 1682]
Maximum f: 18.497459684989913
Maximum residual: 5.0034871778360134e-08
Residual: 0.001358882476347059
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12972 344 1684]
Maximum f: 18.525185730811714
Maximum residual: 5.149810670863012e-08
Residual: 1.9054454163196622e-06
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12972 344 1684]
Maximum f: 18.525206055073866
Maximum residual: 5.150080515301985e-08
Residual: 9.810763002019388e-12
Load increment #46, load: 21.644736842105264, initial residual: 0.00126985069032032
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12934 352 1714]
Maximum f: 18.86455367000332
Maximum residual: 5.239548525852323e-08
Residual: 0.0034265177396679168
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12928 344 1728]
Maximum f: 19.125341198308256
Maximum residual: 4.9679574383078104e-08
Residual: 0.001376389292018058
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12928 346 1726]
Maximum f: 19.155549876635167
Maximum residual: 5.0098608406680944e-08
Residual: 1.4515813989821835e-06
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12928 346 1726]
Maximum f: 19.15556356815058
Maximum residual: 5.0099352048911193e-08
Residual: 4.780033890906139e-12
Load increment #47, load: 21.736842105263158, initial residual: 0.0012698506902894579
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12893 356 1751]
Maximum f: 19.459795366327228
Maximum residual: 6.185634534537616e-08
Residual: 0.002792618539106706
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12892 347 1761]
Maximum f: 19.793917140529487
Maximum residual: 3.621899085417811e-08
Residual: 0.0008070305255768422
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12892 346 1762]
Maximum f: 19.804972222490616
Maximum residual: 3.535777161157398e-08
Residual: 3.955582862509976e-07
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12892 346 1762]
Maximum f: 19.80497303654093
Maximum residual: 3.535758552274675e-08
Residual: 2.500985032978027e-13
Load increment #48, load: 21.82894736842105, initial residual: 0.0012698506902874022
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12863 360 1777]
Maximum f: 20.193528253113605
Maximum residual: 5.585553935437403e-08
Residual: 0.0027907396976720607
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12859 359 1782]
Maximum f: 20.46175456840382
Maximum residual: 4.95928473604865e-08
Residual: 0.0003265015952555808
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12859 357 1784]
Maximum f: 20.46885734263891
Maximum residual: 5.291211967802182e-08
Residual: 5.679156538753961e-07
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12859 357 1784]
Maximum f: 20.468863982698906
Maximum residual: 5.291524267451849e-08
Residual: 1.6532077313042398e-12
Load increment #49, load: 21.92105263157895, initial residual: 0.0012698506902833204
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3 4]
Counts of unique number of iterations: [12825 363 1811 1]
Maximum f: 20.949662773773305
Maximum residual: 6.744386502997123e-08
Residual: 0.0024192645022178116
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3 4]
Counts of unique number of iterations: [12819 362 1818 1]
Maximum f: 21.187953026114158
Maximum residual: 6.70764688806747e-08
Residual: 0.0010026947524404838
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3 4]
Counts of unique number of iterations: [12819 361 1819 1]
Maximum f: 21.221498101732124
Maximum residual: 5.3456395870395396e-08
Residual: 2.7592792073247533e-06
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3 4]
Counts of unique number of iterations: [12819 361 1819 1]
Maximum f: 21.22153633814161
Maximum residual: 5.345395306362722e-08
Residual: 2.6870784238022427e-11
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 2 3 4]
Counts of unique number of iterations: [12819 361 1819 1]
Maximum f: 21.22153633869432
Maximum residual: 5.345395200221594e-08
Residual: 5.849382027690772e-15
Load increment #50, load: 22.013157894736842, initial residual: 0.0012698506902865342
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12782 373 1845]
Maximum f: 21.67837726461898
Maximum residual: 6.911525124026282e-08
Residual: 0.003856399891474261
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12781 370 1849]
Maximum f: 22.04633301585368
Maximum residual: 1.1225358308091134e-07
Residual: 0.0002635286188395791
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12781 369 1850]
Maximum f: 22.05221164878433
Maximum residual: 4.453561630626364e-08
Residual: 8.556001831536681e-07
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12781 369 1850]
Maximum f: 22.0522205460302
Maximum residual: 4.453418428451616e-08
Residual: 5.2450916262603946e-12
Load increment #51, load: 22.105263157894736, initial residual: 0.0012698506902737436
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12730 385 1885]
Maximum f: 22.506131836943226
Maximum residual: 4.9233417035411634e-08
Residual: 0.004752898321789831
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12728 378 1894]
Maximum f: 23.044392718537193
Maximum residual: 7.172581537267101e-08
Residual: 0.00030029348460328196
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12728 378 1894]
Maximum f: 23.057402668102302
Maximum residual: 7.056687998537865e-08
Residual: 6.047294841259013e-07
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12728 378 1894]
Maximum f: 23.057407689999213
Maximum residual: 7.056558583876245e-08
Residual: 2.162567542423635e-12
Load increment #52, load: 22.19736842105263, initial residual: 0.001269850690288881
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12686 385 1929]
Maximum f: 23.562020905877148
Maximum residual: 5.7785744955179067e-08
Residual: 0.0032275214553501475
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12678 387 1935]
Maximum f: 24.036511762163595
Maximum residual: 6.887958434845373e-08
Residual: 0.0005947417167532816
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12678 384 1938]
Maximum f: 24.06453189307587
Maximum residual: 7.210941811522535e-08
Residual: 1.0602992761082734e-06
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12678 384 1938]
Maximum f: 24.064550796546122
Maximum residual: 7.21134133786955e-08
Residual: 4.401261778407899e-12
Load increment #53, load: 22.289473684210527, initial residual: 0.00126985069030174
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12633 404 1963]
Maximum f: 24.659637969985145
Maximum residual: 1.0416560114781882e-07
Residual: 0.003939447926218508
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12627 393 1980]
Maximum f: 25.276296615258023
Maximum residual: 6.210197204362519e-08
Residual: 0.0005439270780915825
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12627 392 1981]
Maximum f: 25.310041025221334
Maximum residual: 6.253081409605897e-08
Residual: 7.3568262705072e-07
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12627 392 1981]
Maximum f: 25.31006063606528
Maximum residual: 6.253190475799704e-08
Residual: 1.7541647910980747e-12
Load increment #54, load: 22.38157894736842, initial residual: 0.0012698506902882735
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12580 398 2022]
Maximum f: 26.050578009521423
Maximum residual: 6.400990819911765e-08
Residual: 0.005078781536508987
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12574 394 2032]
Maximum f: 26.784959023396947
Maximum residual: 3.095099808386535e-07
Residual: 0.00043096334147957647
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12574 392 2034]
Maximum f: 26.79658273942866
Maximum residual: 3.447143690232908e-07
Residual: 1.888843913877108e-07
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12574 392 2034]
Maximum f: 26.796586912163175
Maximum residual: 3.4472801985876277e-07
Residual: 8.090258570992397e-14
Load increment #55, load: 22.473684210526315, initial residual: 0.001269850690287596
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12513 397 2090]
Maximum f: 27.393746853616477
Maximum residual: 8.071892375981295e-08
Residual: 0.004934508834150478
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12503 381 2116]
Maximum f: 28.52938102208324
Maximum residual: 3.6192799010497234e-07
Residual: 0.0015749954824767073
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12504 381 2115]
Maximum f: 28.602468189695887
Maximum residual: 5.6776056156284854e-08
Residual: 0.00035336556804717973
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12504 382 2114]
Maximum f: 28.604187936900917
Maximum residual: 9.689172228131553e-08
Residual: 4.7163974771225904e-08
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12504 382 2114]
Maximum f: 28.604187355413323
Maximum residual: 9.688937555990352e-08
Residual: 1.5081119417982764e-14
Load increment #56, load: 22.56578947368421, initial residual: 0.0012698506902877613
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12424 396 2180]
Maximum f: 30.201359183137697
Maximum residual: 2.6654542770652415e-07
Residual: 0.005528661784820465
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12418 378 2204]
Maximum f: 32.19780042696306
Maximum residual: 7.865053150917694e-08
Residual: 0.0028584418767026777
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12419 369 2212]
Maximum f: 32.4229424848732
Maximum residual: 5.13480670698914e-08
Residual: 0.0006593256004529714
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12419 370 2211]
Maximum f: 32.43888046553363
Maximum residual: 5.069790012936222e-08
Residual: 7.544842823333701e-07
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12419 370 2211]
Maximum f: 32.4388974236814
Maximum residual: 5.069738751526626e-08
Residual: 2.6058983513360408e-12
Load increment #57, load: 22.657894736842106, initial residual: 0.001269850690298214
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12293 378 2329]
Maximum f: 35.14424112546795
Maximum residual: 7.591360780905614e-08
Residual: 0.00835781824051834
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12292 332 2376]
Maximum f: 40.29081418782346
Maximum residual: 6.750305402849737e-08
Residual: 0.007395964608752622
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12283 336 2381]
Maximum f: 40.99539973153436
Maximum residual: 5.93774919782325e-08
Residual: 0.0008981157869348225
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12282 336 2382]
Maximum f: 41.08029012442486
Maximum residual: 5.8053317131289586e-08
Residual: 5.490400813635445e-05
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12282 334 2384]
Maximum f: 41.080657843086556
Maximum residual: 5.804855004645274e-08
Residual: 2.560650323841959e-08
Outer Newton iteration #5
Inner Newton summary:
Unique number of iterations: [1 2 3]
Counts of unique number of iterations: [12282 334 2384]
Maximum f: 41.08065784768114
Maximum residual: 5.804855218901227e-08
Residual: 1.144072759355354e-14
Load increment #58, load: 22.75, initial residual: 0.0012698506902867688
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1 2 3 4]
Counts of unique number of iterations: [12119 353 2518 10]
Maximum f: 50.146690066294425
Maximum residual: 6.114497214651637e-08
Residual: 0.011255471181193993
Outer Newton iteration #1
Inner Newton summary:
Unique number of iterations: [1 2 3 4]
Counts of unique number of iterations: [12102 287 2559 52]
Maximum f: 75.49977673275367
Maximum residual: 1.2242135609962693e-07
Residual: 0.016138506943861108
Outer Newton iteration #2
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [12094 241 2550 114 1]
Maximum f: 112.38889910971845
Maximum residual: 3.5877558054816494e-07
Residual: 0.010057270212342347
Outer Newton iteration #3
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [12087 234 2500 175 4]
Maximum f: 126.18509352106626
Maximum residual: 4.6086029176779996e-07
Residual: 0.004692827733468935
Outer Newton iteration #4
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [12084 227 2492 192 5]
Maximum f: 130.42945746142556
Maximum residual: 4.5556528308290085e-07
Residual: 0.00363939638471542
Outer Newton iteration #5
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [12082 224 2485 204 5]
Maximum f: 131.97342320880097
Maximum residual: 4.4172690659165635e-07
Residual: 0.000888615530169231
Outer Newton iteration #6
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [12082 222 2488 203 5]
Maximum f: 132.4327347467358
Maximum residual: 4.709657516081837e-07
Residual: 1.895961815676395e-05
Outer Newton iteration #7
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [12082 222 2488 203 5]
Maximum f: 132.43732622485354
Maximum residual: 4.712954181128813e-07
Residual: 4.714596354059165e-09
Outer Newton iteration #8
Inner Newton summary:
Unique number of iterations: [1 2 3 4 5]
Counts of unique number of iterations: [12082 222 2488 203 5]
Maximum f: 132.43732751517913
Maximum residual: 4.712955327299751e-07
Residual: 2.7257746345557226e-14
Slope stability factor: 6.594202898550725
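The loads printed in the log follow a two-stage schedule. The first stage uses equal increments of the self-weight \(\gamma\) up to 21. The second stage uses much finer increments up to 22.75, the last load in the log. The NumPy sketch below rebuilds such a schedule. It is inferred from the printed loads alone, so the start value 2 and the step counts (40 and 20 points) are assumptions, not the tutorial's code. It also compares the ratio of the two step sizes with the ratio of the initial residuals reported for increments #39 and #40.

```python
import numpy as np

# Hypothetical reconstruction of the load schedule, inferred only from the
# loads printed in the log above (not the tutorial's own code):
# 40 equal steps from gamma = 2 to 21, then 19 finer steps up to 22.75.
coarse = np.linspace(2.0, 21.0, 40)
fine = np.linspace(21.0, 22.75, 20)[1:]  # drop the duplicated 21.0
loads = np.concatenate([coarse, fine])

for k in (34, 39, 40, 58):
    print(f"Load increment #{k}: {loads[k]}")

# Load increments of the two stages
step_coarse = coarse[1] - coarse[0]
step_fine = fine[1] - fine[0]

# Initial residuals copied from the log (increments #39 and #40)
res_coarse = 0.006716719402106013
res_fine = 0.0012698506902034635

# Compare the ratio of step sizes with the ratio of initial residuals
print(step_coarse / step_fine, res_coarse / res_fine)
```

Both ratios come out at about 5.29. This suggests that the initial residual of each increment is driven only by the change \(\Delta\gamma\) in the body force, since the previous increment has converged.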
Verification
Critical load
if len(points_on_process) > 0:
    plt.plot(-results[:, 0], results[:, 1], "o-")
    plt.xlabel("Displacement of the slope at (0, H)")
    plt.ylabel(r"Soil self-weight $\gamma$")
    plt.savefig(f"displacement_rank{MPI.COMM_WORLD.rank:d}.png")
    plt.show()
![../_images/cb992bf68c5717b776aadeeb38338edf442d3399e5fa4ad0eba48a66e9379481.png](../_images/cb992bf68c5717b776aadeeb38338edf442d3399e5fa4ad0eba48a66e9379481.png)
print(f"Slope stability factor for 2D plane strain factor [Chen]: {6.69}")
print(f"Computed slope stability factor: {22.75*H/c}")
Slope stability factor for 2D plane strain factor [Chen]: 6.69
Computed slope stability factor: 6.594202898550725
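The computed slope stability factor is \(\gamma H / c\) evaluated at the last load, \(\gamma = 22.75\). A quick check of its relative difference from Chen's plane-strain value needs only the two printed numbers:

```python
# Values copied from the output above
N_chen = 6.69                   # 2D plane-strain reference value [Chen]
N_computed = 6.594202898550725  # 22.75 * H / c computed in this tutorial

# Relative difference with respect to the reference value
rel_diff = abs(N_computed - N_chen) / N_chen
print(f"Relative difference: {100 * rel_diff:.2f} %")  # Relative difference: 1.43 %
```

The reference value is a 2D plane-strain solution for the standard Mohr-Coulomb criterion. The computation is 3D and uses apex smoothing, so an exact match is not expected.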
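# Interpolate the displacement into a first-order vector space for visualization
W = fem.functionspace(domain, ("Lagrange", 1, (gdim,)))
u_tmp = fem.Function(W, name="Displacement")
u_tmp.interpolate(u)
# Start a virtual framebuffer for headless rendering
pyvista.start_xvfb()
plotter = pyvista.Plotter(window_size=[600, 400])
# Build the VTK grid from W so that its points match the dofs of u_tmp
topology, cell_types, x = plot.vtk_mesh(W)
grid = pyvista.UnstructuredGrid(topology, cell_types, x)
# PyVista expects 3-component point data, so pad with zeros if gdim < 3
vals = np.zeros((x.shape[0], 3))
vals[:, : len(u_tmp)] = u_tmp.x.array.reshape((x.shape[0], len(u_tmp)))
grid["u"] = vals
# Warp the mesh by the displacement field, magnified 20 times
warped = grid.warp_by_vector("u", factor=20)
plotter.add_text("Displacement field", font_size=11)
plotter.add_mesh(warped, show_edges=False, show_scalar_bar=True)
plotter.view_xy()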
if not pyvista.OFF_SCREEN:
    plotter.show()
![Displacement field, warped by a factor of 20](../_images/8a49985188b263e9ae1e2547e1bbdadbe0e2beb477fb60b8d0f46a771c5f5b5e.png)
Yield surface#
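We verify that the constitutive model is correctly implemented by tracing the yield surface. We generate several stress paths and check whether they remain within the yield surface. The stress tracing is performed in the Haigh-Westergaard coordinates \((\xi, \rho, \theta)\), which are defined as follows:
\[
\xi = \frac{1}{\sqrt{3}} \operatorname{tr}(\boldsymbol{\sigma}), \qquad
\rho = \sqrt{2 J_2(\boldsymbol{\sigma})}, \qquad
\sin(3\theta) = -\frac{3\sqrt{3}}{2} \frac{J_3(\boldsymbol{\sigma})}{J_2^{3/2}(\boldsymbol{\sigma})},
\]
where \(\boldsymbol{s}\) is the deviatoric part of the stress tensor, \(J_2(\boldsymbol{\sigma}) = \frac{1}{2}\boldsymbol{s}:\boldsymbol{s}\) and \(J_3(\boldsymbol{\sigma}) = \det(\boldsymbol{s})\) are its second and third invariants, \(\xi\) is the hydrostatic coordinate, \(\rho\) is the radial coordinate, and the angle \(\theta \in [-\frac{\pi}{6}, \frac{\pi}{6}]\) is called the Lode (or stress) angle.
By introducing the hydrostatic variable \(p = \xi/\sqrt{3}\), the principal stresses can be written in Haigh-Westergaard coordinates as
\[
\begin{pmatrix} \sigma_{I} \\ \sigma_{II} \\ \sigma_{III} \end{pmatrix}
= p \begin{pmatrix} 1 \\ 1 \\ 1 \end{pmatrix}
+ \sqrt{\frac{2}{3}} \, \rho \begin{pmatrix} \sin(\theta + \frac{2\pi}{3}) \\ \sin\theta \\ \sin(\theta - \frac{2\pi}{3}) \end{pmatrix}.
\]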
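Below is a minimal NumPy sketch of the Haigh-Westergaard map and its inverse. The function names are hypothetical and not part of the tutorial. It assumes \(\xi = \operatorname{tr}\boldsymbol{\sigma}/\sqrt{3}\), \(\rho = \sqrt{2J_2}\), and the sine convention \(\sin(3\theta) = -\frac{3\sqrt{3}}{2} J_3/J_2^{3/2}\) used by the angle function below. Principal stresses are written as \(p + \sqrt{2/3}\,\rho\,[\sin(\theta + 2\pi/3), \sin\theta, \sin(\theta - 2\pi/3)]\). A round trip should recover the input whenever \(\theta \in [-\pi/6, \pi/6]\).

```python
import numpy as np


def haigh_westergaard(principal):
    """Return (xi, rho, theta) for principal stresses (s_I, s_II, s_III)."""
    sig = np.asarray(principal, dtype=float)
    xi = sig.sum() / np.sqrt(3.0)            # hydrostatic coordinate
    s = sig - sig.mean()                     # principal deviatoric stresses
    J2 = 0.5 * np.dot(s, s)
    J3 = np.prod(s)                          # det(s) for a diagonal deviator
    rho = np.sqrt(2.0 * J2)                  # radial coordinate
    arg = -3.0 * np.sqrt(3.0) * J3 / (2.0 * J2**1.5)
    theta = np.arcsin(np.clip(arg, -1.0, 1.0)) / 3.0  # Lode angle in [-pi/6, pi/6]
    return xi, rho, theta


def principal_from_haigh_westergaard(xi, rho, theta):
    """p + sqrt(2/3) rho [sin(theta + 2pi/3), sin(theta), sin(theta - 2pi/3)]."""
    p = xi / np.sqrt(3.0)
    shifts = np.array([2.0 * np.pi / 3.0, 0.0, -2.0 * np.pi / 3.0])
    return p + np.sqrt(2.0 / 3.0) * rho * np.sin(theta + shifts)


# Round trip for an arbitrary state with p = 1, rho = 0.7, theta = 0.2
sig_rt = principal_from_haigh_westergaard(np.sqrt(3.0), 0.7, 0.2)
xi_rt, rho_rt, theta_rt = haigh_westergaard(sig_rt)
print(sig_rt, xi_rt, rho_rt, theta_rt)
```

For \(\theta\) in this range, the sine form also returns the principal stresses already ordered as \(\sigma_I \geq \sigma_{II} \geq \sigma_{III}\).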
First, we define and vectorize the functions rho, angle and sigma_tracing. For a given stress state, they evaluate the coordinate \(\rho\), the Lode angle \(\theta\), and the stress tensor corrected by the return-mapping algorithm, respectively.
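def rho(sigma_local):
    # Radial Haigh-Westergaard coordinate: rho = sqrt(2 J2(s))
    s = dev @ sigma_local
    return jnp.sqrt(2.0 * J2(s))
def angle(sigma_local):
    # Lode angle: sin(3 theta) = -3 sqrt(3) J3 / (2 J2^(3/2))
    s = dev @ sigma_local
    arg = -(3.0 * jnp.sqrt(3.0) * J3(s)) / (2.0 * jnp.sqrt(J2(s) * J2(s) * J2(s)))
    # Guard against round-off pushing the argument slightly outside [-1, 1]
    arg = jnp.clip(arg, -1.0, 1.0)
    theta = 1.0 / 3.0 * jnp.arcsin(arg)
    return theta
def sigma_tracing(sigma_local, sigma_n_local):
    # Strain increment that would produce the stress increment sigma_local elastically
    deps_elas = S_elas @ sigma_local
    # Correct the trial stress with the return-mapping algorithm
    sigma_corrected, state = return_mapping(deps_elas, sigma_n_local)
    # Yield-function value stored in the returned state
    yielding = state[2]
    return sigma_corrected, yielding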
angle_v = jax.jit(jax.vmap(angle, in_axes=(0)))
rho_v = jax.jit(jax.vmap(rho, in_axes=(0)))
sigma_tracing_vec = jax.jit(jax.vmap(sigma_tracing, in_axes=(0, 0)))
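The functions rho and angle rely on helpers (dev, J2, J3) defined earlier in the tutorial and are batched with jax.vmap. For an independent check of the invariants, here is a plain NumPy sketch. It assumes the Mandel ordering \([\sigma_{xx}, \sigma_{yy}, \sigma_{zz}, \sqrt{2}\sigma_{xy}, \sqrt{2}\sigma_{xz}, \sqrt{2}\sigma_{yz}]\), which may differ from the tutorial's own convention. It also replaces jax.vmap with an explicit loop, and the function names are hypothetical. Under the sine convention, uniaxial tension should give \(\theta = -\pi/6\) and pure shear \(\theta = 0\).

```python
import numpy as np

SQRT2 = np.sqrt(2.0)


def mandel_to_tensor(v):
    """3x3 stress tensor from a Mandel vector (assumed ordering:
    [s_xx, s_yy, s_zz, sqrt2*s_xy, sqrt2*s_xz, sqrt2*s_yz])."""
    xx, yy, zz = v[0], v[1], v[2]
    xy, xz, yz = v[3] / SQRT2, v[4] / SQRT2, v[5] / SQRT2
    return np.array([[xx, xy, xz], [xy, yy, yz], [xz, yz, zz]])


def rho_and_angle(stresses):
    """rho and Lode angle for each row of an (N, 6) array of Mandel vectors.

    The explicit loop plays the role of jax.vmap; J2 must be non-zero."""
    rhos, thetas = [], []
    for v in np.atleast_2d(stresses):
        sig = mandel_to_tensor(v)
        s = sig - np.trace(sig) / 3.0 * np.eye(3)  # deviatoric part
        J2 = 0.5 * np.tensordot(s, s)              # 0.5 * s:s
        J3 = np.linalg.det(s)
        arg = -3.0 * np.sqrt(3.0) * J3 / (2.0 * J2**1.5)
        rhos.append(np.sqrt(2.0 * J2))
        thetas.append(np.arcsin(np.clip(arg, -1.0, 1.0)) / 3.0)
    return np.array(rhos), np.array(thetas)


# Uniaxial tension and pure shear (tau = 1) as a small batch
batch = np.array([[1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                  [0.0, 0.0, 0.0, SQRT2, 0.0, 0.0]])
rhos, thetas = rho_and_angle(batch)
print(rhos, thetas)
```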
Second, we generate a loading path by evaluating the principal stresses through the Haigh-Westergaard coordinates, keeping \(\rho\) and \(\xi\) fixed.
N_angles = 200
N_loads = 10
eps = 1e-7
R = 0.7
p = 1.0
angle_values = np.linspace(0 + eps, 2 * np.pi - eps, N_angles)
dsigma_path = np.zeros((N_angles, stress_dim))
dsigma_path[:, 0] = np.sqrt(2.0 / 3.0) * R * np.cos(angle_values)
dsigma_path[:, 1] = np.sqrt(2.0 / 3.0) * R * np.sin(angle_values - np.pi / 6.0)
dsigma_path[:, 2] = np.sqrt(2.0 / 3.0) * R * np.sin(-angle_values - np.pi / 6.0)
angle_results = np.empty((N_loads, N_angles))
rho_results = np.empty((N_loads, N_angles))
sigma_results = np.empty((N_loads, N_angles, stress_dim))
sigma_n_local = np.zeros_like(dsigma_path)
sigma_n_local[:, 0] = p
sigma_n_local[:, 1] = p
sigma_n_local[:, 2] = p
hydrostatic_axis = tr
print(f"rho = {R}, p = {p} - projection onto the octahedral plane\n")
for i in range(N_loads):
    print(f"Loading#{i}")
    dsigma, yielding = sigma_tracing_vec(dsigma_path, sigma_n_local)
    # Offset of the mean stress from p
    dp = dsigma @ tr / 3.0 - p
    dsigma -= np.outer(dp, hydrostatic_axis)  # projection on the same octahedral plane
    sigma_results[i, :] = dsigma
    angle_results[i, :] = angle_v(dsigma)
    rho_results[i, :] = rho_v(dsigma)
    print(f"max f: {jnp.max(yielding)}\n")
    # The corrected stresses become the initial state of the next loading step
    sigma_n_local[:] = dsigma
rho = 0.7, p = 1.0 - projection onto the octahedral plane
Loading#0
max f: -1.5556538807558753
Loading#1
max f: -1.1973911329550662
Loading#2
max f: -0.757988187228174
Loading#3
max f: -0.28548830756036025
Loading#4
max f: 0.20272750659611782
Loading#5
max f: 0.5549672091686113
Loading#6
max f: 0.6598541870035599
Loading#7
max f: 0.6908007847906061
Loading#8
max f: 0.6999099297010107
Loading#9
max f: 0.702589415869515
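In the loading-path cell, only the first three (normal) components of dsigma_path are filled, so each increment is a diagonal stress lying on a circle of radius R in the deviatoric plane. The initial state sigma_n_local sits on the hydrostatic axis at mean stress p. The NumPy sketch below rebuilds these components with the same formulas. deviatoric_circle and the other names are hypothetical. It checks three properties:

- each increment is traceless;
- its norm equals \(\rho = R\);
- it coincides with the sine form \(\sqrt{2/3}\,\rho\,[\sin(\theta + 2\pi/3), \sin\theta, \sin(\theta - 2\pi/3)]\) for \(\theta = a - \pi/6\), where \(a\) runs over angle_values.

```python
import numpy as np


def deviatoric_circle(R, n_angles, eps=1e-7):
    """Normal components of the stress increments, built as in the loading-path cell."""
    a = np.linspace(0 + eps, 2 * np.pi - eps, n_angles)
    c = np.sqrt(2.0 / 3.0) * R
    path = np.column_stack([
        c * np.cos(a),
        c * np.sin(a - np.pi / 6.0),
        c * np.sin(-a - np.pi / 6.0),
    ])
    return a, path


R_check = 0.7
a_values, path_check = deviatoric_circle(R_check, 200)

# Sine form of the principal deviatoric stresses with theta = a - pi/6
theta_check = a_values - np.pi / 6.0
shifts = np.array([2.0 * np.pi / 3.0, 0.0, -2.0 * np.pi / 3.0])
sine_form = np.sqrt(2.0 / 3.0) * R_check * np.sin(theta_check[:, None] + shifts)

print("traceless:", np.allclose(path_check.sum(axis=1), 0.0))
print("radius R :", np.allclose(np.linalg.norm(path_check, axis=1), R_check))
print("sine form:", np.allclose(path_check, sine_form))
```

The loop variable is therefore the Lode angle shifted by \(\pi/6\). Because it sweeps a full turn, the whole deviatoric circle is covered.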
Finally, the stress paths are represented by a series of nested circles lying in the same octahedral plane. By applying the return-mapping algorithm defined in the function return_mapping, we correct the stress paths. Once they reach the elastic limit, the traced curves resemble the Mohr-Coulomb yield surface with apex smoothing. This indicates that the constitutive model is implemented correctly.
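Inside the loading loop, the corrected stresses are shifted along the hydrostatic axis (dp = dsigma @ tr / 3.0 - p, then subtracting dp times the hydrostatic axis). This keeps every traced state at the mean stress p, presumably because the plastic correction of a non-associative model may change the mean stress. Below is a minimal NumPy sketch of this projection. It assumes that tr holds ones on the three normal components and zeros on the shear ones, as its use for the mean stress suggests. The name project_to_octahedral_plane is hypothetical.

```python
import numpy as np

# Assumed: normal components first, as implied by `dsigma @ tr / 3.0` in the loop
tr_check = np.array([1.0, 1.0, 1.0, 0.0, 0.0, 0.0])


def project_to_octahedral_plane(stresses, p):
    """Shift each row along the hydrostatic axis so that its mean stress equals p."""
    dp = stresses @ tr_check / 3.0 - p        # offset of the mean stress from p
    return stresses - np.outer(dp, tr_check)  # only normal components change


stresses = np.array([[1.5, 0.2, 2.1, 0.3, 0.0, -0.1],
                     [0.4, 0.9, 1.1, 0.0, 0.2, 0.0]])
projected = project_to_octahedral_plane(stresses, p=1.0)
print(projected @ tr_check / 3.0)  # mean stress of each projected row
```

The projection leaves the deviatoric part untouched. Differences between normal components and all shear components are preserved, so \(\rho\) and \(\theta\) are unaffected.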
fig, ax = plt.subplots(subplot_kw={"projection": "polar"}, figsize=(8, 8))
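for j in range(12):
    for i in range(N_loads):
        # Mirror theta in [-pi/6, pi/6] into the sector centred at j*pi/3:
        # +theta for even j, -theta for odd j
        ax.plot(j * np.pi / 3 - j % 2 * angle_results[i] + (1 - j % 2) * angle_results[i], rho_results[i], ".")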
ax.set_yticklabels([])
fig.tight_layout()
![Traced stress paths in the octahedral plane (polar plot of ρ versus the mirrored Lode angle)](../_images/243c957bf5ca6560796b39b90f649a2bd9a5e61c2a7a4bf883954513930b8fbd.png)
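The plotting loop maps each computed Lode angle \(\theta \in [-\pi/6, \pi/6]\) into the sector centred at \(j\pi/3\), using \(+\theta\) for even \(j\) and \(-\theta\) for odd \(j\). This relies on the symmetry of an isotropic yield surface in the octahedral plane, where neighbouring sectors of width \(\pi/3\) are mirror images. A small sketch of the mapping follows; mirrored_polar_angles is a hypothetical helper. It shows that six sectors already cover the full circle, so the loop's \(j = 6, \dots, 11\) retrace the same sectors shifted by \(2\pi\).

```python
import numpy as np


def mirrored_polar_angles(theta, n_sectors=6):
    """Polar angles j*pi/3 + theta (even j) or j*pi/3 - theta (odd j), one array per sector."""
    theta = np.asarray(theta, dtype=float)
    return [j * np.pi / 3.0 + (theta if j % 2 == 0 else -theta) for j in range(n_sectors)]


theta_samples = np.linspace(-np.pi / 6.0, np.pi / 6.0, 11)
sectors = mirrored_polar_angles(theta_samples)
all_angles = np.concatenate(sectors)
print(all_angles.min(), all_angles.max())  # from -pi/6 to 11*pi/6
```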
Taylor test#
The derivatives on which the form \(F\) and its Jacobian \(J\) are based are computed automatically with the JAX AD tools. We therefore perform the Taylor test to ensure that these derivatives are computed correctly.
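Indeed, by Taylor's theorem, if we perturb the residual \(F\) in the direction \(h \, \boldsymbol{\delta u} \in V\) for \(h > 0\), the zeroth- and first-order Taylor remainders \(R_0\) and \(R_1\) converge at the following rates as \(h \to 0\):
\[
R_0 = \| F(\boldsymbol{u} + h \, \boldsymbol{\delta u}) - F(\boldsymbol{u}) \| = O(h), \qquad
R_1 = \| F(\boldsymbol{u} + h \, \boldsymbol{\delta u}) - F(\boldsymbol{u}) - h \, J(\boldsymbol{u}) \, \boldsymbol{\delta u} \| = O(h^2).
\]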
The following code blocks implement the Taylor test and check the first- and second-order convergence rates.
# Reset the main variables to zero, including the values of the external operators
sigma_n.x.array[:] = 0.0
sigma.ref_coefficient.x.array[:] = 0.0
J_external_operators[0].ref_coefficient.x.array[:] = 0.0
# Reset the consistent tangent matrix to the elastic moduli: a uniform
# displacement increment produces zero strain, hence a purely elastic response
Du.x.array[:] = 1.0
evaluated_operands = evaluate_operands(F_external_operators)
_ = evaluate_external_operators(J_external_operators, evaluated_operands)
Inner Newton summary:
Unique number of iterations: [1]
Counts of unique number of iterations: [15000]
Maximum f: -2.2109628558449494
Maximum residual: 0.0
Since the derivatives of the constitutive model differ between the elastic and plastic phases, we must consider two initial states for the Taylor test. For this reason, we first solve the problem for a fixed load value to obtain an initial state close to plastic yielding.
i = 0
load = 2.0
q.value = load * np.array([0, -gamma])
external_operator_problem.assemble_vector()
residual_0 = external_operator_problem.b.norm()
residual = residual_0
Du.x.array[:] = 0
if MPI.COMM_WORLD.rank == 0:
    print(f"Load increment #{i}, load: {load}, initial residual: {residual_0}")
for iteration in range(0, max_iterations):
    if residual / residual_0 < relative_tolerance:
        break
    if MPI.COMM_WORLD.rank == 0:
        print(f"\tOuter Newton iteration #{iteration}")
    external_operator_problem.assemble_matrix()
    external_operator_problem.solve(du)
    Du.vector.axpy(1.0, du.vector)
    Du.x.scatter_forward()
    evaluated_operands = evaluate_operands(F_external_operators)
    ((_, sigma_new),) = evaluate_external_operators(J_external_operators, evaluated_operands)
    sigma.ref_coefficient.x.array[:] = sigma_new
    external_operator_problem.assemble_vector()
    residual = external_operator_problem.b.norm()
    if MPI.COMM_WORLD.rank == 0:
        print(f"\tResidual: {residual}\n")
sigma_n.x.array[:] = sigma.ref_coefficient.x.array
# Initial values of the displacement field and the stress state for the Taylor test
Du0 = np.copy(Du.x.array)
sigma_n0 = np.copy(sigma_n.x.array)
Load increment #0, load: 2.0, initial residual: 0.027573900703382316
Outer Newton iteration #0
Inner Newton summary:
Unique number of iterations: [1]
Counts of unique number of iterations: [15000]
Maximum f: -0.2983932974714918
Maximum residual: 0.0
Residual: 7.890919449890783e-14
If we take the initial stress state sigma_n0 computed in the cell above into account, the Taylor test is performed in the plastic phase; otherwise, it stays in the elastic phase.
h_list = np.logspace(-2.0, -6.0, 5)[::-1]
def perform_Taylor_test(Du0, sigma_n0):
    # F(Du0 + h*δu) - F(Du0) - h*J(Du0)*δu
    Du.x.array[:] = Du0
    sigma_n.x.array[:] = sigma_n0
    evaluated_operands = evaluate_operands(F_external_operators)
    ((_, sigma_new),) = evaluate_external_operators(J_external_operators, evaluated_operands)
    sigma.ref_coefficient.x.array[:] = sigma_new
    F0 = fem.petsc.assemble_vector(F_form)  # F(Du0)
    F0.ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE)
    J0 = fem.petsc.assemble_matrix(J_form)
    J0.assemble()  # J(Du0)
    y = J0.createVecLeft()  # y = J0 @ x
    δu = fem.Function(V)
    δu.x.array[:] = Du0  # δu == Du0
    zero_order_remainder = np.zeros_like(h_list)
    first_order_remainder = np.zeros_like(h_list)
    for i, h in enumerate(h_list):
        Du.x.array[:] = Du0 + h * δu.x.array
        evaluated_operands = evaluate_operands(F_external_operators)
        ((_, sigma_new),) = evaluate_external_operators(J_external_operators, evaluated_operands)
        sigma.ref_coefficient.x.array[:] = sigma_new
        F_delta = fem.petsc.assemble_vector(F_form)  # F(Du0 + h*δu)
        F_delta.ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE)
        J0.mult(δu.vector, y)  # y = J(Du0)*δu
        y.scale(h)  # y = h*y
        zero_order_remainder[i] = (F_delta - F0).norm()
        first_order_remainder[i] = (F_delta - F0 - y).norm()
    return zero_order_remainder, first_order_remainder
print("Elastic phase")
zero_order_remainder_elastic, first_order_remainder_elastic = perform_Taylor_test(Du0, 0.0)
print("Plastic phase")
zero_order_remainder_plastic, first_order_remainder_plastic = perform_Taylor_test(Du0, sigma_n0)
Elastic phase
Inner Newton summary:
Unique number of iterations: [1]
Counts of unique number of iterations: [15000]
Maximum f: -0.2983932974714918
Maximum residual: 0.0
Inner Newton summary:
Unique number of iterations: [1]
Counts of unique number of iterations: [15000]
Maximum f: -0.29839091658838823
Maximum residual: 0.0
Inner Newton summary:
Unique number of iterations: [1]
Counts of unique number of iterations: [15000]
Maximum f: -0.29836948862876245
Maximum residual: 0.0
Inner Newton summary:
Unique number of iterations: [1]
Counts of unique number of iterations: [15000]
Maximum f: -0.2981552078749816
Maximum residual: 0.0
Inner Newton summary:
Unique number of iterations: [1]
Counts of unique number of iterations: [15000]
Maximum f: -0.2960122846812845
Maximum residual: 0.0
Inner Newton summary:
Unique number of iterations: [1]
Counts of unique number of iterations: [15000]
Maximum f: -0.2745715837020426
Maximum residual: 0.0
Plastic phase
Inner Newton summary:
Unique number of iterations: [1 3]
Counts of unique number of iterations: [14989 11]
Maximum f: 2.1523374160661706
Maximum residual: 5.235343387682854e-10
Inner Newton summary:
Unique number of iterations: [1 3]
Counts of unique number of iterations: [14989 11]
Maximum f: 2.15233990398981
Maximum residual: 5.235366372814165e-10
Inner Newton summary:
Unique number of iterations: [1 3]
Counts of unique number of iterations: [14989 11]
Maximum f: 2.152362295304328
Maximum residual: 5.235579917241567e-10
Inner Newton summary:
Unique number of iterations: [1 3]
Counts of unique number of iterations: [14989 11]
Maximum f: 2.152586208624277
Maximum residual: 5.237749000378985e-10
Inner Newton summary:
Unique number of iterations: [1 3]
Counts of unique number of iterations: [14989 11]
Maximum f: 2.154825359292723
Maximum residual: 5.259447116609796e-10
Inner Newton summary:
Unique number of iterations: [1 3]
Counts of unique number of iterations: [14989 11]
Maximum f: 2.1772186045752187
Maximum residual: 5.47923987245938e-10
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].loglog(h_list, zero_order_remainder_elastic, "o-", label=r"$R_0$")
axs[0].loglog(h_list, first_order_remainder_elastic, "o-", label=r"$R_1$")
annotation.slope_marker((5e-5, 5e-6), 1, ax=axs[0], poly_kwargs={"facecolor": "tab:blue"})
axs[1].loglog(h_list, zero_order_remainder_plastic, "o-", label=r"$R_0$")
annotation.slope_marker((5e-5, 5e-6), 1, ax=axs[1], poly_kwargs={"facecolor": "tab:blue"})
axs[1].loglog(h_list, first_order_remainder_plastic, "o-", label=r"$R_1$")
annotation.slope_marker((1e-4, 5e-13), 2, ax=axs[1], poly_kwargs={"facecolor": "tab:orange"})
for i in range(2):
    axs[i].set_xlabel("h")
    axs[i].set_ylabel("Taylor remainder")
    axs[i].legend()
    axs[i].grid()
plt.tight_layout()
first_order_rate = np.polyfit(np.log(h_list), np.log(zero_order_remainder_elastic), 1)[0]
second_order_rate = np.polyfit(np.log(h_list), np.log(first_order_remainder_elastic), 1)[0]
print(f"Elastic phase:\n\tthe 1st order rate = {first_order_rate:.2f}\n\tthe 2nd order rate = {second_order_rate:.2f}")
first_order_rate = np.polyfit(np.log(h_list), np.log(zero_order_remainder_plastic), 1)[0]
second_order_rate = np.polyfit(np.log(h_list[1:]), np.log(first_order_remainder_plastic[1:]), 1)[0]
print(f"Plastic phase:\n\tthe 1st order rate = {first_order_rate:.2f}\n\tthe 2nd order rate = {second_order_rate:.2f}")
Elastic phase:
the 1st order rate = 1.00
the 2nd order rate = 0.00
Plastic phase:
the 1st order rate = 1.00
the 2nd order rate = 1.92
![Taylor remainders R0 and R1 versus h: elastic phase (left), plastic phase (right)](../_images/0ad023e7a7b7d36d52f3f626d329c0fd2447b347c578246132ee596c6eb161f0.png)
For the elastic phase (left), the zeroth-order Taylor remainder \(R_0\) converges at first order, as it does for the plastic phase (right). During the elastic response, however, the first-order remainder \(R_1\) does not decrease with \(h\) (fitted rate 0.00). The Jacobian is constant in this case, so the residual is affine in the displacement and \(R_1\) vanishes up to round-off. In the plastic phase, by contrast, \(R_1\) converges at second order (fitted rate 1.92).
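The behaviour of both remainders can be reproduced on a toy problem without any finite element machinery. The NumPy sketch below applies the same Taylor test to two residuals; all functions and data in it are hypothetical:

- a nonlinear residual \(F(u) = A u + u^3\) (componentwise cube);
- an affine residual \(F(u) = A u - b\) with a constant Jacobian, which mimics the elastic phase.

For the nonlinear residual, the fitted rates should approach 1 and 2. For the affine one, \(R_1\) should stay at round-off level for every \(h\).

```python
import numpy as np


def taylor_remainders(F, J, u0, du, hs):
    """Zeroth- and first-order Taylor remainders of F around u0 in direction du."""
    F0 = F(u0)
    J0_du = J(u0) @ du
    R0 = np.array([np.linalg.norm(F(u0 + h * du) - F0) for h in hs])
    R1 = np.array([np.linalg.norm(F(u0 + h * du) - F0 - h * J0_du) for h in hs])
    return R0, R1


def fitted_rate(hs, R):
    """Slope of log(R) versus log(h), estimated with np.polyfit as in the tutorial."""
    return np.polyfit(np.log(hs), np.log(R), 1)[0]


A_toy = np.array([[2.0, -1.0], [-1.0, 2.0]])
b_toy = np.array([1.0, -0.5])


def F_nonlinear(u):
    return A_toy @ u + u**3


def J_nonlinear(u):
    return A_toy + np.diag(3.0 * u**2)


def F_affine(u):
    # Constant Jacobian, like the elastic phase of the constitutive model
    return A_toy @ u - b_toy


def J_affine(u):
    return A_toy


u0_toy = np.array([0.3, -0.2])
du_toy = np.array([1.0, 0.5])
hs_toy = np.logspace(-4.0, -1.0, 4)

R0_nl, R1_nl = taylor_remainders(F_nonlinear, J_nonlinear, u0_toy, du_toy, hs_toy)
R0_aff, R1_aff = taylor_remainders(F_affine, J_affine, u0_toy, du_toy, hs_toy)
print(f"nonlinear: R0 rate = {fitted_rate(hs_toy, R0_nl):.2f}, R1 rate = {fitted_rate(hs_toy, R1_nl):.2f}")
print(f"affine:    R0 rate = {fitted_rate(hs_toy, R0_aff):.2f}, max R1 = {R1_aff.max():.1e}")
```

Fitting a rate to the affine \(R_1\) would be meaningless, and could even fail if some values are exactly zero. That is why only its magnitude is reported.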